Compare commits

...

29 Commits

Author SHA1 Message Date
Yota Hamada
4633850d2d
feat(all): storing output data (#1511)
* **New Features**
* Added DAG run outputs API endpoint to retrieve execution outputs with
metadata
* New outputs visualization UI with search, filtering, and
copy-to-clipboard functionality
* Flexible output configuration supporting custom keys and omit options
  * Timeline view for DAG execution visualization

* **Improvements**
  * Automatic snake_case to camelCase conversion for output keys
  * Secret masking applied to collected outputs
* Metadata tracking includes DAG name, run ID, attempt ID, status, and
completion timestamp
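
As a rough illustration, a step output declared via the existing `output` field would be collected into `outputs.json` and exposed under a camelCase key (the exact DAG syntax for custom keys and omit options is not shown in this summary, so treat this as a sketch):

```yaml
steps:
  - name: count-rows
    command: wc -l < data.csv
    output: TOTAL_COUNT   # collected as "totalCount" in outputs.json
```

The collected values, along with run metadata, can then be fetched from the new endpoint `GET /dag-runs/{name}/{dagRunId}/outputs`.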
2025-12-28 14:37:41 +09:00
Yota Hamada
140afbaffe
feat(all): API key management (#1509)
* **New Features**
* Full API Key management: create, list, retrieve, update, revoke; keys
usable for authentication.

* **API**
* OpenAPI/schema and server routes added for API key CRUD with
comprehensive request/response/error variants.

* **Auth & Middleware**
* Service-level generation, validation, last-used tracking and
middleware path for API key auth.

* **Storage & Config**
* File-backed API key store with in-memory index and new config path for
API keys directory.

* **UI**
* Admin API Keys page, nav item, create/edit modal, list and revoke
flows; simplified role label.

* **Tests**
* Extensive unit/integration tests covering store, service, API,
middleware, and UI.
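
For orientation, a create request against the new endpoint would carry a body like the following, rendered here as YAML for readability. Field names come from the `CreateAPIKeyRequest` schema added in this PR; the `operator` role value is an assumption drawn from the role set used elsewhere in the codebase:

```yaml
# CreateAPIKeyRequest (sketch)
name: ci-deployer
description: Key used by the CI pipeline
role: operator        # any valid UserRole value
```

The full key secret is returned once in the create response; only a `keyPrefix` (the first 8 characters) is kept for later identification.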
2025-12-27 19:21:44 +09:00
Yota Hamada
345ef13565 feat(schema): logOutput 2025-12-27 19:05:28 +09:00
Yota Hamada
df257b49e8
feat(core): logOutput config option (#1508)
Issue #1505
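
The commit carries no summary beyond the issue link; going by the option name, usage might look like the sketch below (placement and semantics are assumptions, not confirmed by this diff):

```yaml
# hypothetical DAG-level usage of the new option
logOutput: false      # assumed: controls whether step output is written to logs
steps:
  - command: ./noisy-job.sh
```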
2025-12-27 04:15:30 +09:00
Yota Hamada
03e692d6d8
fix(config): auth token option is not loaded (#1507)
* **New Features**
* Config loading now performs context-aware variable expansion in config
values.
* Evaluation expanded to handle pointer, slice and array fields
(including nested, optional elements and proper error propagation).

* **Tests**
* Added tests verifying env-var expansion during config load (auth token
case).
* Added extensive tests covering pointer, slice and array evaluation
behaviors, substitutions, and error scenarios.
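
In practice the fix means a config value that references an environment variable is now expanded when the file is loaded, e.g. (the exact key name for the auth token is an assumption):

```yaml
# config.yaml (sketch; key name assumed)
auth:
  token: ${DAGU_API_TOKEN}   # now expanded from the environment at load time
```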
2025-12-27 04:08:43 +09:00
Yota Hamada
3e9d49d1e0
feat(all): multiple commands syntax (#1506)
* **New Features**
* Steps can contain multiple commands executed sequentially; UIs and
runners display and run command lists.

* **Validation**
* Executors advertise capabilities and produce clearer errors when
features conflict or are unsupported (multi-command vs
sub-workflow/script/container).

* **API**
* Step payloads now use a commands list for command entries; legacy
single-command fields remain supported for compatibility.

* **Integration & UX**
* Schemas, APIs, reporting, and UI components updated to present and
handle the commands array.
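
A sketch of the new syntax, assuming the DAG YAML mirrors the `commands` list introduced in the API payloads:

```yaml
steps:
  - name: build-and-test
    commands:          # executed sequentially, top to bottom
      - make build
      - make test
```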
2025-12-27 03:42:15 +09:00
Yota Hamada
64f1884e45
fix(ui): refactor small UI issues (#1504)
* **New Features**
* Added unsaved changes detection with a confirmation prompt when
attempting to run a DAG with pending edits.

* **Bug Fixes**
* Fixed log line counting and pagination to correctly display actual
rendered lines.
* Improved text selection visibility in log viewers with enhanced
styling.
2025-12-24 21:13:54 +09:00
Yota Hamada
1fb869dda8
feat(ui): add tab layout for DAG definitions page (#1503)
* **New Features**
* Multi-tab support for managing multiple DAGs; tab bar with add/close
controls
* Card view for DAG list on narrow panels; preserved across panel sizes

* **Bug Fixes**
  * Better handling when a DAG is not found or has been deleted
  * Improved modal open/close animation lifecycle

* **UI/UX Improvements**
  * Responsive panel-width measurement and layout refinements
  * Updated table and tab styling; longer fade-in animation
  * Disabled automatic browser translation; reworked authentication mode initialization
2025-12-24 18:43:01 +09:00
Yota Hamada
05452f9669
feat(ui): Remove discord and github links (#1502)
* **New Features**
  * Added user menu component to the sidebar
  * Added version display at the bottom of the sidebar

* **UI/UX Changes**
  * Removed Discord and GitHub social links from the sidebar
  * Updated authentication configuration
2025-12-24 00:37:27 +09:00
Yota Hamada
d0193cb8c0
feat(core): Add step-level container field (#1501)
* **New Features**
* Steps can now specify their own container configuration, overriding
DAG-level containers.
* Schema updated to support step-level container field and clarify env
merge behavior.

* **Bug Fixes**
* Validation prevents using both container and executor on the same
step.
* Validation prevents using script together with container (use command
instead).

* **Improvements**
* Enhanced Docker/runtime handling: env merging, working-dir/volumes
shortcuts, runtime evaluation, and richer logs.

* **Tests**
* Added unit and integration tests covering step containers, evaluation
helpers, and conflict cases.
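
A sketch of the override behavior described above, assuming the step-level block mirrors the DAG-level `container` field (the `image` key is an assumption):

```yaml
container:
  image: alpine:3          # DAG-level default container
steps:
  - name: build
    container:
      image: golang:1.23   # hypothetical step-level override
    command: go build ./...
```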
2025-12-24 00:35:42 +09:00
Yota Hamada
8945b926b5 chore: replace image 2025-12-23 11:37:02 +09:00
Yota Hamada
d87b8c6dff
refactor(core/spec): restructure spec types and build logic (#1499)
* **Refactor**
* Simplified manifest/DAG build flow into a smaller, type-driven
pipeline and reduced public build surface; improved step dependency
handling and defaults.
  * Adjusted DAG name validation behavior.

* **New Features**
* Added typed YAML parsers (shell, env, schedule, ports, continue-on,
string-or-array, parallel) and improved multi-document/base-manifest
decoding.
* Load .env values earlier during runtime initialization to affect
secret resolution.

* **Tests**
* Vastly expanded unit and integration tests for YAML loading, step
building, params, schema resolution, secrets, parallelism, handlers, and
loader options.
2025-12-23 02:19:36 +09:00
Yota Hamada
dcb3cf1570 doc: add example graph 2025-12-22 10:32:21 +09:00
Yota Hamada
d2b3d8fdd3
feat(ui): update frontend design (#1498)
2025-12-21 18:59:00 +09:00
Yota Hamada
f3d4577e42
fix(cmd): refactor start command not to prevent DAG from running unexpectedly (#1497)
* **Behavior Changes**
* Removed explicit no-queue and disable-max-active-runs options;
start/retry flows simplified to default local execution and streamlined
retry semantics.

* **New Features**
* Singleton mode now returns clear HTTP 409 conflicts when a singleton
DAG is already running or queued.
* Added top-level run Error field and an API to record early failures
for quicker failure visibility.

* **Bug Fixes**
* Improved process acquisition and restart/retry error handling; tests
updated to reflect local execution behavior.
2025-12-21 18:42:34 +09:00
Yota Hamada
5d90d2744f
feat(spec): add containerName field to DAG-level container (#1496)
2025-12-20 21:10:19 +09:00
Yota Hamada
f04786c5e1
feat(all): Add support for multiple-level nested local DAGs (#1493)
* **New Features**
* Optional parentSubDAGRunId query to fetch nested sub-DAG runs; sub-run
lists and modal show root/parent DAG context, status filters with
counts, and improved keyboard navigation.

* **Bug Fixes / UX**
* Consolidated local DAG handling for multi-level execution, improved
error responses for missing runs, and small UI spacing tidy-ups.

* **Tests**
* Reorganized and expanded integration/unit tests for multi-level
sub-DAG scenarios and temp-DAG file handling.
2025-12-20 20:56:08 +09:00
Yota Hamada
887355f0ca chore: update logo image 2025-12-19 09:23:41 +09:00
Yota Hamada
eef457b4c2
fix(store): exact match for DAG name lookup in listRoot (#1490)
* **Bug Fixes**
* Fixed filtering logic for DAG runs to require exact name matching
instead of substring matching, refining which results are returned when
filtering is applied.
2025-12-17 00:33:57 +09:00
Yota Hamada
c72623154d
feat(cmd): add cleanup command (#1489)
* **New Features**
* Added `cleanup` command to remove old DAG run history with
configurable retention periods
* Supports `--dry-run` flag to preview which runs would be removed
without deleting
  * Includes `--yes` flag to skip confirmation prompts
2025-12-16 23:36:24 +09:00
Yota Hamada
5d6e50df04
fix(auth): allow API token to perform write/execute operations (#1486)
Issue #1478 

* **Bug Fixes**
* API tokens now grant admin privileges in builtin authentication mode,
enabling write, execute, and delete operations.

* **Tests**
* New test coverage verifying API token authentication in builtin mode
allows admin-level create, start, and delete actions.
2025-12-15 20:11:20 +09:00
Yota Hamada
cb83c59a6d doc: add new sponsor 2025-12-15 08:45:15 +09:00
Yota Hamada
be3e71b79a
fix(ui): user management visibility and reset password API (#1484)
* **Bug Fixes**
* Updated password reset API endpoint and request method for
compatibility.

* **New Features**
* User Management menu item now displays based on admin privileges and
authentication configuration.
  * Reset Password action restricted to admin users only.
2025-12-14 19:22:54 +09:00
Yota Hamada
9841e6ed70
feat(api): add singleton flag to enqueue API (#1483)
* **New Features**
* Added a singleton flag to DAG enqueue requests to prevent duplicate
runs when a DAG is already running or queued.

* **API Changes**
* Enqueue endpoint now can return HTTP 409 Conflict for singleton-mode
conflicts.
  * Added more specific authentication error codes.

* **Behavioral Changes**
* Enqueue/retry paths no longer perform prior maxActiveRuns queue-length
enforcement unless singleton is used.
* Queue concurrency now honors DAG-configured max active runs for
DAG-based queues.

* **Tests**
* Added integration test validating max-active-runs behavior for
DAG-based queues.
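
Per the OpenAPI change in this compare, the enqueue request body gains a single boolean; a minimal body (shown as YAML) would be:

```yaml
# enqueue request body (sketch)
singleton: true   # rejected with 409 Conflict if the DAG is already running or queued
```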
2025-12-14 19:05:53 +09:00
Yota Hamada
3ebfa3cbf2
fix(scheduler): queue processor should respect maxConcurrency config in global config (#1482)
* **New Features**
  * Configurable exponential backoff for scheduler retries.
* Global queues from configuration persist when idle and keep configured
concurrency.

* **Improvements**
  * Increased server startup timeout for greater reliability.
* Queue management now distinguishes persistent global queues from
dynamic queues.

* **Tests**
  * Added integration test validating global-queue concurrency.
* Updated scheduler tests to inject a fast backoff config for
deterministic retries.
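
A sketch of the global queue configuration this fix targets (the structure and key names are assumptions based on the commit description):

```yaml
# global config (sketch; names assumed)
queues:
  - name: default
    maxConcurrency: 2   # now respected by the queue processor
```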
2025-12-14 17:56:09 +09:00
Kriyanshi
23a29f336a
fix(runtime): include queue name in dequeue command (#1481)
The SubCmdBuilder.Dequeue method was missing the queue name parameter
when building the dequeue command. The CLI command expects the queue
name as the first positional argument after "dequeue", but it was not
being included, causing dequeue operations via the API to fail.

This change:
- Updates Dequeue to use dag.ProcGroup() to get the queue name
- Includes the queue name as the first argument after "dequeue"
- Updates tests to verify the queue name is included
- Adds test case for DAGs with custom queue names

Fixes the issue where dequeuing from the API failed because the queue
name was not being passed to the CLI command.
2025-12-14 17:26:34 +09:00
Yota Hamada
5e7ce83afa
fix(auth): allow API token auth to work alongside builtin mode (#1480)
* **New Features**
* JWT login/usage added; multiple auth methods (JWT, API token, Basic,
OIDC) can be enabled and used together.
* New auth-related API error codes: auth.unauthorized,
auth.token_invalid, auth.forbidden.
* **Behavior**
* Built-in auth mode ignores Basic auth (warning exposed); health and
login endpoints remain public.
* App clears stale built-in tokens at startup; Authorization headers
applied reliably.
* **Tests**
* Extensive auth tests and improved test helpers for auth flows and
headers.
* **Chores**
  * Persisted users path added to test configs.
2025-12-14 17:24:48 +09:00
Yota Hamada
0498640f11
test(spec): add integration tests for special envs (#1479)
* **Tests**
* Added comprehensive tests validating environment variable visibility
across all handler types and execution phases.

* **Refactor**
* Simplified configuration handling by removing a global locking layer
around config access.

* **Chores**
* Improved test isolation by switching to local configuration instances
to avoid shared global state.
2025-12-14 08:48:27 +09:00
Yota Hamada
a75f667f95
fix(ui): escape config path string injected in html (#1472)
* **Chores**
* Added Windows build automation scripts for the UI and application
build pipeline with dependency management and error handling.
* Updated frontend template path handling with JavaScript escaping for
enhanced security.
2025-12-11 01:34:27 +09:00
320 changed files with 39251 additions and 12443 deletions

.vscode/launch.json (vendored; 2 lines changed)

@@ -42,7 +42,7 @@
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/cmd/",
"args": ["start", "--no-queue", "${input:pathToSpec}"]
"args": ["start", "${input:pathToSpec}"]
},
{
"name": "Retry",

README.md (122 lines changed)

@@ -31,23 +31,17 @@ Built for developers who want powerful workflow orchestration without the operat
## Why Dagu?
### 🚀 Zero Dependencies
**Single binary. No database, no message broker.** Deploy anywhere in seconds—from your laptop to bare metal servers to Kubernetes. Everything is stored in plain files (XDG compliant), making it transparent, portable, and easy to backup.
Many workflow orchestrators already exist, and Apache Airflow is a well-known example. In Airflow, DAGs are loaded from Python source files, so defining workflows typically means writing and maintaining Python code. In real deployments, Airflow commonly involves multiple running components (for example, scheduler, webserver, metadata database, and workers), and DAG files often need to be synchronized across them, which can increase operational complexity.
### 🧩 Composable Nested Workflows
**Build complex pipelines from reusable building blocks.** Define sub-workflows that can be called with parameters, executed in parallel, and fully monitored in the UI. See execution traces for every nested level—no black boxes.
Dagu is a self-contained workflow engine where workflows are defined in simple YAML and executed with a single binary. It is designed to run without requiring external databases or message brokers, using local files for definitions, logs, and metadata. Because it orchestrates commands rather than forcing you into a specific programming model, it is easy to integrate existing shell scripts and operational commands as they are. Our goal is to make Dagu an ideal workflow engine for small teams that want orchestration power with minimal setup and operational overhead.
### 🌐 Language Agnostic
**Use your existing scripts without modification.** No need to wrap everything in Python decorators or rewrite logic. Dagu orchestrates shell commands, Docker containers, SSH commands, or HTTP calls—whatever you already have.
### ⚡ Distributed Execution
**Built-in queue system with intelligent task routing.** Route tasks to workers based on labels (GPU, region, compliance requirements). Automatic service registry and health monitoring included—no external coordination service needed.
### 🎯 Production Ready
**Not a toy.** Battle-tested error handling with exponential backoff retries, lifecycle hooks (onSuccess, onFailure, onExit), real-time log streaming, email notifications, Prometheus metrics, and OpenTelemetry tracing out of the box. Built-in user management with role-based access control (RBAC) for team environments.
### 🎨 Modern Web UI
**Beautiful UI that actually helps you debug.** Live log tailing, DAG visualization with Gantt charts, execution history with full lineage, and drill-down into nested sub-workflows. Dark mode included.
## Highlights
- Single binary file installation
- Declarative YAML format for defining DAGs
- Web UI for visually managing, rerunning, and monitoring pipelines
- Use existing programs without any modification
- Self-contained, with no need for a DBMS
## Quick Start
@@ -154,6 +148,62 @@ dagu start-all
Visit http://localhost:8080
## Quick Look for Workflow Definitions
### Sequential Steps
Steps execute one after another:
```yaml
type: chain
steps:
- command: echo "Hello, dagu!"
- command: echo "This is a second step"
```
```mermaid
%%{init: {'theme': 'base', 'themeVariables': {'background': '#3A322C', 'primaryTextColor': '#fff', 'lineColor': '#888'}}}%%
graph LR
A["Step 1"] --> B["Step 2"]
style A fill:#3A322C,stroke:green,stroke-width:1.6px,color:#fff
style B fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
```
### Parallel Steps
Steps with dependencies run in parallel:
```yaml
type: graph
steps:
- id: step_1
command: echo "Step 1"
- id: step_2a
command: echo "Step 2a - runs in parallel"
depends: [step_1]
- id: step_2b
command: echo "Step 2b - runs in parallel"
depends: [step_1]
- id: step_3
command: echo "Step 3 - waits for parallel steps"
depends: [step_2a, step_2b]
```
```mermaid
%%{init: {'theme': 'base', 'themeVariables': {'background': '#3A322C', 'primaryTextColor': '#fff', 'lineColor': '#888'}}}%%
graph LR
A[step_1] --> B[step_2a]
A --> C[step_2b]
B --> D[step_3]
C --> D
style A fill:#3A322C,stroke:green,stroke-width:1.6px,color:#fff
style B fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
style C fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
style D fill:#3A322C,stroke:lightblue,stroke-width:1.6px,color:#fff
```
For more examples, see the [Examples](https://docs.dagu.cloud/writing-workflows/examples) documentation.
## Docker-Compose
Clone the repository and run with Docker Compose:
@@ -496,6 +546,37 @@ For discussions, support, and sharing ideas, join our community on [Discord](htt
Changelog of recent updates can be found in the [Changelog](https://docs.dagu.cloud/reference/changelog) section of the documentation.
## Acknowledgements
### Sponsors & Supporters
<div align="center">
<h3>💜 Premium Sponsors</h3>
<a href="https://github.com/slashbinlabs">
<img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fslashbinlabs.png&w=150&h=150&fit=cover&mask=circle" width="100" height="100" alt="@slashbinlabs">
</a>
<h3>✨ Supporters</h3>
<a href="https://github.com/disizmj">
<img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fdisizmj.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@disizmj" style="margin: 5px;">
</a>
<a href="https://github.com/Arvintian">
<img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2FArvintian.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@Arvintian" style="margin: 5px;">
</a>
<a href="https://github.com/yurivish">
<img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fyurivish.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@yurivish" style="margin: 5px;">
</a>
<a href="https://github.com/jayjoshi64">
<img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fjayjoshi64.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@jayjoshi64" style="margin: 5px;">
</a>
<br/><br/>
<a href="https://github.com/sponsors/dagu-org">
<img src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86" width="150" alt="Sponsor">
</a>
</div>
## Contributing
We welcome contributions of all kinds! Whether you're a developer, a designer, or a user, your help is valued. Here are a few ways to get involved:
@@ -507,8 +588,6 @@ We welcome contributions of all kinds! Whether you're a developer, a designer, o
For more details, see our [Contribution Guide](./CONTRIBUTING.md).
## Acknowledgements
### Contributors
<a href="https://github.com/dagu-org/dagu/graphs/contributors">
@@ -517,15 +596,6 @@ For more details, see our [Contribution Guide](./CONTRIBUTING.md).
Thanks to all the contributors who have helped make Dagu better! Your contributions, whether through code, documentation, or feedback, are invaluable to the project.
### Sponsors & Supporters
<a href="https://github.com/disizmj"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fdisizmj.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@disizmj"></a>
<a href="https://github.com/Arvintian"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2FArvintian.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@Arvintian"></a>
<a href="https://github.com/yurivish"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fyurivish.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@yurivish"></a>
<a href="https://github.com/jayjoshi64"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fjayjoshi64.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@jayjoshi64"></a>
Thanks for supporting Dagu's development! Join our supporters: [GitHub Sponsors](https://github.com/sponsors/dagu-org)
## License
GNU GPLv3 - See [LICENSE](./LICENSE)

File diff suppressed because it is too large

@@ -37,6 +37,8 @@ tags:
description: "Authentication operations (login, logout, token management)"
- name: "users"
description: "User management operations (CRUD, password management)"
- name: "api-keys"
description: "API key management operations (admin only)"
paths:
/health:
@@ -56,10 +58,6 @@ paths:
default:
description: "Unexpected error"
# ============================================================================
# Authentication Endpoints
# ============================================================================
/auth/login:
post:
summary: "Authenticate user and obtain JWT token"
@@ -435,6 +433,225 @@
schema:
$ref: "#/components/schemas/Error"
# API Key Management (Admin only)
/api-keys:
get:
summary: "List all API keys"
description: "Returns all API keys. Requires admin role."
operationId: "listAPIKeys"
tags:
- "api-keys"
responses:
"200":
description: "List of API keys"
content:
application/json:
schema:
$ref: "#/components/schemas/APIKeysListResponse"
"401":
description: "Not authenticated"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: "Requires admin role"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Error"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
post:
summary: "Create API key"
description: "Full key returned only in this response"
operationId: "createAPIKey"
tags:
- "api-keys"
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CreateAPIKeyRequest"
responses:
"201":
description: "Created"
content:
application/json:
schema:
$ref: "#/components/schemas/CreateAPIKeyResponse"
"400":
description: "Invalid request"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: "Not authenticated"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: "Requires admin role"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"409":
description: "Name already exists"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Error"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/api-keys/{keyId}:
get:
summary: "Get API key"
description: "Returns API key by ID. Requires admin role."
operationId: "getAPIKey"
tags:
- "api-keys"
parameters:
- $ref: "#/components/parameters/APIKeyId"
responses:
"200":
description: "API key details"
content:
application/json:
schema:
$ref: "#/components/schemas/APIKeyResponse"
"401":
description: "Not authenticated"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: "Requires admin role"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: "Not found"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Error"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
patch:
summary: "Update API key"
description: "Updates API key info. Requires admin role."
operationId: "updateAPIKey"
tags:
- "api-keys"
parameters:
- $ref: "#/components/parameters/APIKeyId"
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/UpdateAPIKeyRequest"
responses:
"200":
description: "Updated API key"
content:
application/json:
schema:
$ref: "#/components/schemas/APIKeyResponse"
"400":
description: "Invalid request"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: "Not authenticated"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: "Requires admin role"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: "Not found"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"409":
description: "Name already exists"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Error"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
delete:
summary: "Delete API key"
description: "Revokes an API key. Requires admin role."
operationId: "deleteAPIKey"
tags:
- "api-keys"
parameters:
- $ref: "#/components/parameters/APIKeyId"
responses:
"204":
description: "API key deleted"
"401":
description: "Not authenticated"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"403":
description: "Requires admin role"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"404":
description: "Not found"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Error"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/workers:
get:
summary: "List distributed workers"
@@ -797,6 +1014,10 @@ paths:
queue:
type: string
description: "Override the DAG-level queue definition"
singleton:
type: boolean
description: "If true, prevent enqueuing if DAG is already running or queued (returns 409 conflict)"
default: false
responses:
"200":
description: "A successful response"
@@ -810,6 +1031,12 @@
description: "ID of the created DAG-run"
required:
- dagRunId
"409":
description: "DAG is already running or queued (singleton mode)"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Generic error response"
content:
@@ -1472,7 +1699,7 @@ paths:
/dag-runs/{name}/{dagRunId}/sub-dag-runs:
get:
summary: "Get sub DAG runs with timing info"
description: "Retrieves timing and status information for all sub DAG runs (including repeated executions) of a specific step"
description: "Retrieves timing and status information for all sub DAG runs (including repeated executions) of a specific step. When parentSubDAGRunId is provided, returns sub-runs of that specific sub DAG run (for multi-level nested DAGs)."
operationId: "getSubDAGRuns"
tags:
- "dag-runs"
@@ -1480,6 +1707,12 @@
- $ref: "#/components/parameters/RemoteNode"
- $ref: "#/components/parameters/DAGName"
- $ref: "#/components/parameters/DAGRunId"
- name: "parentSubDAGRunId"
in: "query"
required: false
schema:
type: "string"
description: "Optional parent sub DAG run ID. When provided, returns sub-runs of this specific sub DAG run instead of the root DAG run. Used for multi-level nested DAGs."
responses:
"200":
description: "A successful response"
@@ -1569,6 +1802,37 @@
schema:
$ref: "#/components/schemas/Error"
/dag-runs/{name}/{dagRunId}/outputs:
get:
summary: "Retrieve collected outputs from a DAG-run"
description: "Fetches the outputs.json file containing all step outputs collected during the DAG-run execution. Returns the outputs as a JSON object where keys are the output names (converted from UPPER_CASE to camelCase by default, or custom key if specified) and values are the captured output strings."
operationId: "getDAGRunOutputs"
tags:
- "dag-runs"
parameters:
- $ref: "#/components/parameters/RemoteNode"
- $ref: "#/components/parameters/DAGName"
- $ref: "#/components/parameters/DAGRunId"
responses:
"200":
description: "Successfully retrieved outputs. Returns the collected outputs with metadata. If the DAG-run completed but captured no outputs, returns an empty outputs object."
content:
application/json:
schema:
$ref: "#/components/schemas/DAGRunOutputs"
"404":
description: "DAG-run not found"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
default:
description: "Generic error response"
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/dag-runs/{name}/{dagRunId}/retry:
post:
summary: "Retry DAG-run execution"
@@ -2065,6 +2329,15 @@ components:
type: string
minLength: 1
APIKeyId:
name: keyId
in: path
description: unique identifier of the API key
required: true
schema:
type: string
minLength: 1
PerPage:
name: perPage
in: query
@@ -2233,6 +2506,9 @@
- "max_run_reached"
- "not_running"
- "already_exists"
- "auth.unauthorized"
- "auth.token_invalid"
- "auth.forbidden"
Stream:
type: string
@@ -2737,6 +3013,52 @@
required:
- nodes
DAGRunOutputs:
type: object
description: "Collected outputs from step executions in a DAG-run, including execution metadata. If the DAG-run completed but no outputs were captured, the outputs object will be empty and metadata fields may be empty strings."
required:
- metadata
- outputs
properties:
metadata:
$ref: "#/components/schemas/OutputsMetadata"
outputs:
type: object
description: "Collected step outputs as key-value pairs. Keys are output names (UPPER_CASE converted to camelCase by default, or custom key if specified) and values are the captured output strings. Empty object if no outputs were captured."
additionalProperties:
type: string
example:
totalCount: "42"
resultFile: "/path/to/result.txt"
config: '{"key": "value"}'
OutputsMetadata:
type: object
description: "Execution context metadata for the outputs"
required:
- dagName
- dagRunId
- attemptId
- status
- completedAt
properties:
dagName:
$ref: "#/components/schemas/DAGName"
dagRunId:
$ref: "#/components/schemas/DAGRunId"
attemptId:
type: string
description: "Attempt identifier within the run"
status:
$ref: "#/components/schemas/StatusLabel"
completedAt:
type: string
format: date-time
description: "RFC3339 timestamp when execution completed"
params:
type: string
description: "JSON-serialized parameters passed to the DAG"
Node:
type: object
description: "Status of an individual step within a DAG-run"
@@ -2806,8 +3128,7 @@
description: "Detailed information for a sub DAG-run including timing and status"
properties:
dagRunId:
type: string
description: "Unique identifier for the sub DAG-run"
$ref: "#/components/schemas/DAGRunId"
params:
type: string
description: "Parameters passed to the sub DAG-run in JSON format"
@@ -2843,12 +3164,11 @@
dir:
type: string
description: "Working directory for executing the step's command"
cmdWithArgs:
type: string
description: "Complete command string including arguments to execute"
command:
type: string
description: "Base command to execute without arguments"
commands:
type: array
description: "List of commands to execute sequentially"
items:
$ref: "#/components/schemas/CommandEntry"
script:
type: string
description: "Script content if the step executes a script file"
@@ -2861,11 +3181,6 @@
output:
type: string
description: "Variable name to store the step's output"
args:
type: array
description: "List of arguments to pass to the command"
items:
type: string
call:
type: string
description: "The name of the DAG to execute as a sub DAG-run"
@@ -3046,6 +3361,21 @@
items:
type: integer
CommandEntry:
type: object
description: "A command with its arguments"
properties:
command:
type: string
description: "The command to execute"
args:
type: array
description: "Arguments for the command"
items:
type: string
required:
- command
ListTagResponse:
type: object
description: "Response object for listing all tags"
@@ -3124,26 +3454,20 @@
description: "Information about a task currently being executed"
properties:
dagRunId:
type: string
description: "ID of the DAG run being executed"
$ref: "#/components/schemas/DAGRunId"
dagName:
type: string
description: "Name of the DAG being executed"
$ref: "#/components/schemas/DAGName"
startedAt:
type: string
description: "RFC3339 timestamp when the task started"
rootDagRunName:
type: string
description: "Name of the root DAG run"
$ref: "#/components/schemas/DAGName"
rootDagRunId:
type: string
description: "ID of the root DAG run"
$ref: "#/components/schemas/DAGRunId"
parentDagRunName:
type: string
description: "Name of the parent DAG run"
$ref: "#/components/schemas/DAGName"
parentDagRunId:
type: string
description: "ID of the parent DAG run"
$ref: "#/components/schemas/DAGRunId"
required:
- dagRunId
- dagName
@@ -3415,6 +3739,118 @@
required:
- users
# ============================================================================
APIKey:
type: object
description: "API key information"
properties:
id:
type: string
description: "Unique identifier"
name:
type: string
description: "Human-readable name"
description:
type: string
description: "Purpose description"
role:
$ref: "#/components/schemas/UserRole"
keyPrefix:
type: string
description: "First 8 characters for identification"
createdAt:
type: string
format: date-time
description: "Creation timestamp"
updatedAt:
type: string
format: date-time
description: "Last update timestamp"
createdBy:
type: string
description: "Creator user ID"
lastUsedAt:
type: string
format: date-time
nullable: true
description: "Last authentication timestamp"
required:
- id
- name
- role
- keyPrefix
- createdAt
- updatedAt
- createdBy
APIKeyResponse:
type: object
description: "API key response"
properties:
apiKey:
$ref: "#/components/schemas/APIKey"
required:
- apiKey
APIKeysListResponse:
type: object
description: "List of API keys"
properties:
apiKeys:
type: array
items:
$ref: "#/components/schemas/APIKey"
required:
- apiKeys
CreateAPIKeyRequest:
type: object
description: "Create API key request"
properties:
name:
type: string
minLength: 1
maxLength: 100
description: "Human-readable name"
description:
type: string
maxLength: 500
description: "Purpose description"
role:
$ref: "#/components/schemas/UserRole"
required:
- name
- role
CreateAPIKeyResponse:
type: object
description: "Create API key response"
properties:
apiKey:
$ref: "#/components/schemas/APIKey"
key:
type: string
description: "Full key secret, only returned once"
required:
- apiKey
- key
UpdateAPIKeyRequest:
type: object
description: "Update API key request"
properties:
name:
type: string
minLength: 1
maxLength: 100
description: "New name"
description:
type: string
maxLength: 500
description: "New description"
role:
$ref: "#/components/schemas/UserRole"
SuccessResponse:
type: object
description: "Generic success response"

Binary image file changed (not shown): 95 KiB before, 60 KiB after.


@@ -45,6 +45,7 @@ func init() {
rootCmd.AddCommand(cmd.Retry())
rootCmd.AddCommand(cmd.StartAll())
rootCmd.AddCommand(cmd.Migrate())
rootCmd.AddCommand(cmd.Cleanup())
config.Version = version
}

internal/auth/apikey.go (new file; 115 lines)

@@ -0,0 +1,115 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package auth
import (
"time"
"github.com/google/uuid"
)
// APIKey represents a standalone API key in the system.
// API keys are independent entities with their own role assignment,
// enabling programmatic access with fine-grained permissions.
type APIKey struct {
// ID is the unique identifier for the API key (UUID).
ID string `json:"id"`
// Name is a human-readable name for the API key (required).
Name string `json:"name"`
// Description is an optional description of the API key's purpose.
Description string `json:"description,omitempty"`
// Role determines the API key's permissions.
Role Role `json:"role"`
// KeyHash is the bcrypt hash of the API key secret.
// Excluded from JSON serialization for security.
KeyHash string `json:"-"`
// KeyPrefix stores the first 8 characters of the key for identification.
KeyPrefix string `json:"key_prefix"`
// CreatedAt is the timestamp when the API key was created.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is the timestamp when the API key was last modified.
UpdatedAt time.Time `json:"updated_at"`
// CreatedBy is the user ID of the admin who created the API key.
CreatedBy string `json:"created_by"`
// LastUsedAt is the timestamp when the API key was last used for authentication.
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
}
// NewAPIKey creates an APIKey with a new UUID and sets CreatedAt and UpdatedAt to the current UTC time.
// It validates that required fields are not empty and the role is valid.
// Returns an error if validation fails.
func NewAPIKey(name, description string, role Role, keyHash, keyPrefix, createdBy string) (*APIKey, error) {
if name == "" {
return nil, ErrInvalidAPIKeyName
}
if keyHash == "" {
return nil, ErrInvalidAPIKeyHash
}
if !role.Valid() {
return nil, ErrInvalidRole
}
now := time.Now().UTC()
return &APIKey{
ID: uuid.New().String(),
Name: name,
Description: description,
Role: role,
KeyHash: keyHash,
KeyPrefix: keyPrefix,
CreatedAt: now,
UpdatedAt: now,
CreatedBy: createdBy,
}, nil
}
// APIKeyForStorage is used for JSON serialization to persistent storage.
// It includes the key hash which is excluded from the regular APIKey JSON.
type APIKeyForStorage struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Role Role `json:"role"`
KeyHash string `json:"key_hash"`
KeyPrefix string `json:"key_prefix"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
CreatedBy string `json:"created_by"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
}
// ToStorage converts an APIKey to APIKeyForStorage for persistence.
// NOTE: When adding new fields to APIKey or APIKeyForStorage, ensure both
// ToStorage and ToAPIKey are updated to maintain field synchronization.
func (k *APIKey) ToStorage() *APIKeyForStorage {
return &APIKeyForStorage{
ID: k.ID,
Name: k.Name,
Description: k.Description,
Role: k.Role,
KeyHash: k.KeyHash,
KeyPrefix: k.KeyPrefix,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
CreatedBy: k.CreatedBy,
LastUsedAt: k.LastUsedAt,
}
}
// ToAPIKey converts APIKeyForStorage back to APIKey.
// NOTE: When adding new fields to APIKey or APIKeyForStorage, ensure both
// ToStorage and ToAPIKey are updated to maintain field synchronization.
func (s *APIKeyForStorage) ToAPIKey() *APIKey {
return &APIKey{
ID: s.ID,
Name: s.Name,
Description: s.Description,
Role: s.Role,
KeyHash: s.KeyHash,
KeyPrefix: s.KeyPrefix,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
CreatedBy: s.CreatedBy,
LastUsedAt: s.LastUsedAt,
}
}


@@ -0,0 +1,294 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package auth
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAPIKey(t *testing.T) {
before := time.Now().UTC()
key, err := NewAPIKey("test-key", "Test description", RoleManager, "hash123", "dagu_tes", "creator-id")
after := time.Now().UTC()
require.NoError(t, err)
assert.NotEmpty(t, key.ID, "ID should be generated")
assert.Equal(t, "test-key", key.Name)
assert.Equal(t, "Test description", key.Description)
assert.Equal(t, RoleManager, key.Role)
assert.Equal(t, "hash123", key.KeyHash)
assert.Equal(t, "dagu_tes", key.KeyPrefix)
assert.Equal(t, "creator-id", key.CreatedBy)
assert.Nil(t, key.LastUsedAt)
// CreatedAt and UpdatedAt should be set to approximately now
assert.True(t, !key.CreatedAt.Before(before), "CreatedAt should be >= before")
assert.True(t, !key.CreatedAt.After(after), "CreatedAt should be <= after")
assert.Equal(t, key.CreatedAt, key.UpdatedAt, "CreatedAt and UpdatedAt should be equal initially")
}
func TestNewAPIKey_Validation(t *testing.T) {
tests := []struct {
name string
keyName string
description string
role Role
keyHash string
keyPrefix string
createdBy string
wantErr error
}{
{
name: "empty name returns error",
keyName: "",
description: "desc",
role: RoleViewer,
keyHash: "hash",
keyPrefix: "prefix",
createdBy: "creator",
wantErr: ErrInvalidAPIKeyName,
},
{
name: "empty key hash returns error",
keyName: "test-key",
description: "desc",
role: RoleViewer,
keyHash: "",
keyPrefix: "prefix",
createdBy: "creator",
wantErr: ErrInvalidAPIKeyHash,
},
{
name: "invalid role returns error",
keyName: "test-key",
description: "desc",
role: Role("invalid"),
keyHash: "hash",
keyPrefix: "prefix",
createdBy: "creator",
wantErr: ErrInvalidRole,
},
{
name: "valid input returns no error",
keyName: "test-key",
description: "desc",
role: RoleViewer,
keyHash: "hash",
keyPrefix: "prefix",
createdBy: "creator",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
key, err := NewAPIKey(tt.keyName, tt.description, tt.role, tt.keyHash, tt.keyPrefix, tt.createdBy)
if tt.wantErr != nil {
assert.ErrorIs(t, err, tt.wantErr)
assert.Nil(t, key)
} else {
assert.NoError(t, err)
assert.NotNil(t, key)
}
})
}
}
func TestAPIKey_ToStorage(t *testing.T) {
now := time.Now().UTC()
lastUsed := now.Add(-time.Hour)
key := &APIKey{
ID: "key-id",
Name: "test-key",
Description: "Test description",
Role: RoleAdmin,
KeyHash: "hash123",
KeyPrefix: "dagu_tes",
CreatedAt: now,
UpdatedAt: now,
CreatedBy: "creator-id",
LastUsedAt: &lastUsed,
}
storage := key.ToStorage()
assert.Equal(t, key.ID, storage.ID)
assert.Equal(t, key.Name, storage.Name)
assert.Equal(t, key.Description, storage.Description)
assert.Equal(t, key.Role, storage.Role)
assert.Equal(t, key.KeyHash, storage.KeyHash)
assert.Equal(t, key.KeyPrefix, storage.KeyPrefix)
assert.Equal(t, key.CreatedAt, storage.CreatedAt)
assert.Equal(t, key.UpdatedAt, storage.UpdatedAt)
assert.Equal(t, key.CreatedBy, storage.CreatedBy)
require.NotNil(t, storage.LastUsedAt)
assert.Equal(t, *key.LastUsedAt, *storage.LastUsedAt)
}
func TestAPIKeyForStorage_ToAPIKey(t *testing.T) {
now := time.Now().UTC()
lastUsed := now.Add(-time.Hour)
storage := &APIKeyForStorage{
ID: "key-id",
Name: "test-key",
Description: "Test description",
Role: RoleViewer,
KeyHash: "hash456",
KeyPrefix: "dagu_xyz",
CreatedAt: now,
UpdatedAt: now,
CreatedBy: "admin-user",
LastUsedAt: &lastUsed,
}
key := storage.ToAPIKey()
assert.Equal(t, storage.ID, key.ID)
assert.Equal(t, storage.Name, key.Name)
assert.Equal(t, storage.Description, key.Description)
assert.Equal(t, storage.Role, key.Role)
assert.Equal(t, storage.KeyHash, key.KeyHash)
assert.Equal(t, storage.KeyPrefix, key.KeyPrefix)
assert.Equal(t, storage.CreatedAt, key.CreatedAt)
assert.Equal(t, storage.UpdatedAt, key.UpdatedAt)
assert.Equal(t, storage.CreatedBy, key.CreatedBy)
require.NotNil(t, key.LastUsedAt)
assert.Equal(t, *storage.LastUsedAt, *key.LastUsedAt)
}
func TestAPIKey_ToStorage_ToAPIKey_Roundtrip(t *testing.T) {
now := time.Now().UTC()
original := &APIKey{
ID: "key-id",
Name: "roundtrip-key",
Description: "Roundtrip test",
Role: RoleOperator,
KeyHash: "secret-hash",
KeyPrefix: "dagu_rnd",
CreatedAt: now,
UpdatedAt: now,
CreatedBy: "creator",
}
// Convert to storage and back
storage := original.ToStorage()
recovered := storage.ToAPIKey()
assert.Equal(t, original.ID, recovered.ID)
assert.Equal(t, original.Name, recovered.Name)
assert.Equal(t, original.Description, recovered.Description)
assert.Equal(t, original.Role, recovered.Role)
assert.Equal(t, original.KeyHash, recovered.KeyHash)
assert.Equal(t, original.KeyPrefix, recovered.KeyPrefix)
assert.Equal(t, original.CreatedAt, recovered.CreatedAt)
assert.Equal(t, original.UpdatedAt, recovered.UpdatedAt)
assert.Equal(t, original.CreatedBy, recovered.CreatedBy)
assert.Equal(t, original.LastUsedAt, recovered.LastUsedAt)
}
func TestAPIKey_JSONSerialization(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second) // Truncate for JSON round-trip
key := &APIKey{
ID: "key-id",
Name: "json-key",
Description: "JSON test",
Role: RoleAdmin,
KeyHash: "should-be-excluded",
KeyPrefix: "dagu_jsn",
CreatedAt: now,
UpdatedAt: now,
CreatedBy: "creator",
}
// Serialize to JSON
data, err := json.Marshal(key)
require.NoError(t, err)
// KeyHash should NOT be in the JSON (json:"-" tag)
jsonStr := string(data)
assert.NotContains(t, jsonStr, "should-be-excluded")
assert.NotContains(t, jsonStr, "key_hash")
// Deserialize back
var recovered APIKey
err = json.Unmarshal(data, &recovered)
require.NoError(t, err)
assert.Equal(t, key.ID, recovered.ID)
assert.Equal(t, key.Name, recovered.Name)
assert.Equal(t, key.Description, recovered.Description)
assert.Equal(t, key.Role, recovered.Role)
assert.Equal(t, key.KeyPrefix, recovered.KeyPrefix)
assert.Empty(t, recovered.KeyHash, "KeyHash should not be deserialized")
}
func TestAPIKeyForStorage_JSONSerialization(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second) // Truncate for JSON round-trip
storage := &APIKeyForStorage{
ID: "key-id",
Name: "storage-key",
Description: "Storage test",
Role: RoleManager,
KeyHash: "included-hash",
KeyPrefix: "dagu_str",
CreatedAt: now,
UpdatedAt: now,
CreatedBy: "admin",
}
// Serialize to JSON
data, err := json.Marshal(storage)
require.NoError(t, err)
// KeyHash SHOULD be in the JSON for storage
jsonStr := string(data)
assert.Contains(t, jsonStr, "included-hash")
assert.Contains(t, jsonStr, "key_hash")
// Deserialize back
var recovered APIKeyForStorage
err = json.Unmarshal(data, &recovered)
require.NoError(t, err)
assert.Equal(t, storage.ID, recovered.ID)
assert.Equal(t, storage.Name, recovered.Name)
assert.Equal(t, storage.Role, recovered.Role)
assert.Equal(t, storage.KeyHash, recovered.KeyHash)
}
func TestAPIKey_NilLastUsedAt(t *testing.T) {
key := &APIKey{
ID: "key-id",
Name: "nil-lastused",
Role: RoleViewer,
KeyHash: "hash",
KeyPrefix: "dagu_nil",
CreatedBy: "creator",
}
// LastUsedAt is nil by default
assert.Nil(t, key.LastUsedAt)
// Convert to storage and back
storage := key.ToStorage()
assert.Nil(t, storage.LastUsedAt)
recovered := storage.ToAPIKey()
assert.Nil(t, recovered.LastUsedAt)
}
func TestNewAPIKey_GeneratesUniqueIDs(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
key, err := NewAPIKey("test", "", RoleViewer, "hash", "prefix", "creator")
require.NoError(t, err)
assert.False(t, ids[key.ID], "ID should be unique")
ids[key.ID] = true
}
}


@@ -0,0 +1,87 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWithUser(t *testing.T) {
t.Run("stores user in context", func(t *testing.T) {
user := NewUser("testuser", "hash", RoleAdmin)
ctx := WithUser(context.Background(), user)
retrieved, ok := UserFromContext(ctx)
require.True(t, ok)
assert.Equal(t, user.ID, retrieved.ID)
assert.Equal(t, user.Username, retrieved.Username)
assert.Equal(t, user.Role, retrieved.Role)
})
t.Run("stores nil user in context", func(t *testing.T) {
ctx := WithUser(context.Background(), nil)
retrieved, ok := UserFromContext(ctx)
assert.True(t, ok)
assert.Nil(t, retrieved)
})
}
func TestUserFromContext(t *testing.T) {
t.Run("returns user when present", func(t *testing.T) {
user := NewUser("admin", "hash", RoleManager)
ctx := WithUser(context.Background(), user)
retrieved, ok := UserFromContext(ctx)
require.True(t, ok)
assert.Equal(t, user, retrieved)
})
t.Run("returns false when user not present", func(t *testing.T) {
ctx := context.Background()
retrieved, ok := UserFromContext(ctx)
assert.False(t, ok)
assert.Nil(t, retrieved)
})
t.Run("returns false when wrong type in context", func(t *testing.T) {
// Manually add wrong type to context with the same key pattern
ctx := context.WithValue(context.Background(), contextKey("auth_user"), "not a user")
retrieved, ok := UserFromContext(ctx)
assert.False(t, ok)
assert.Nil(t, retrieved)
})
t.Run("preserves user through context chain", func(t *testing.T) {
user := NewUser("chainuser", "hash", RoleViewer)
ctx := WithUser(context.Background(), user)
// Add more values to the context chain using a typed key
type otherKey struct{}
ctx = context.WithValue(ctx, otherKey{}, "other_value")
retrieved, ok := UserFromContext(ctx)
require.True(t, ok)
assert.Equal(t, user.Username, retrieved.Username)
})
t.Run("latest user overwrites previous", func(t *testing.T) {
user1 := NewUser("user1", "hash1", RoleAdmin)
user2 := NewUser("user2", "hash2", RoleViewer)
ctx := WithUser(context.Background(), user1)
ctx = WithUser(ctx, user2)
retrieved, ok := UserFromContext(ctx)
require.True(t, ok)
assert.Equal(t, user2.Username, retrieved.Username)
assert.Equal(t, user2.Role, retrieved.Role)
})
}


@@ -134,3 +134,25 @@ func TestAllRoles(t *testing.T) {
t.Error("AllRoles() returned a reference to internal state")
}
}
func TestRole_String(t *testing.T) {
tests := []struct {
role Role
want string
}{
{RoleAdmin, "admin"},
{RoleManager, "manager"},
{RoleOperator, "operator"},
{RoleViewer, "viewer"},
{Role("custom"), "custom"},
{Role(""), ""},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := tt.role.String(); got != tt.want {
t.Errorf("Role(%q).String() = %v, want %v", tt.role, got, tt.want)
}
})
}
}


@@ -21,6 +21,23 @@ var (
ErrInvalidUserID = errors.New("invalid user ID")
)
// Common errors for API key store operations.
var (
// ErrAPIKeyNotFound is returned when an API key cannot be found.
ErrAPIKeyNotFound = errors.New("API key not found")
// ErrAPIKeyAlreadyExists is returned when attempting to create an API key
// with a name that already exists.
ErrAPIKeyAlreadyExists = errors.New("API key already exists")
// ErrInvalidAPIKeyName is returned when the API key name is invalid.
ErrInvalidAPIKeyName = errors.New("invalid API key name")
// ErrInvalidAPIKeyID is returned when the API key ID is invalid.
ErrInvalidAPIKeyID = errors.New("invalid API key ID")
// ErrInvalidAPIKeyHash is returned when the API key hash is empty.
ErrInvalidAPIKeyHash = errors.New("invalid API key hash")
// ErrInvalidRole is returned when the role is not a valid role.
ErrInvalidRole = errors.New("invalid role")
)
// UserStore defines the interface for user persistence operations.
// Implementations must be safe for concurrent use.
type UserStore interface {
@@ -50,3 +67,30 @@ type UserStore interface {
// Count returns the total number of users.
Count(ctx context.Context) (int64, error)
}
// APIKeyStore defines the interface for API key persistence operations.
// Implementations must be safe for concurrent use.
type APIKeyStore interface {
// Create stores a new API key.
// Returns ErrAPIKeyAlreadyExists if an API key with the same name exists.
Create(ctx context.Context, key *APIKey) error
// GetByID retrieves an API key by its unique ID.
// Returns ErrAPIKeyNotFound if the API key does not exist.
GetByID(ctx context.Context, id string) (*APIKey, error)
// List returns all API keys in the store.
List(ctx context.Context) ([]*APIKey, error)
// Update modifies an existing API key.
// Returns ErrAPIKeyNotFound if the API key does not exist.
Update(ctx context.Context, key *APIKey) error
// Delete removes an API key by its ID.
// Returns ErrAPIKeyNotFound if the API key does not exist.
Delete(ctx context.Context, id string) error
// UpdateLastUsed updates the LastUsedAt timestamp for an API key.
// This is called when the API key is used for authentication.
UpdateLastUsed(ctx context.Context, id string) error
}

internal/auth/user_test.go (new file; 230 lines)

@@ -0,0 +1,230 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package auth
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUser(t *testing.T) {
t.Run("creates user with all fields", func(t *testing.T) {
before := time.Now().UTC()
user := NewUser("testuser", "hashedpassword", RoleManager)
after := time.Now().UTC()
assert.NotEmpty(t, user.ID)
assert.Equal(t, "testuser", user.Username)
assert.Equal(t, "hashedpassword", user.PasswordHash)
assert.Equal(t, RoleManager, user.Role)
assert.True(t, !user.CreatedAt.Before(before))
assert.True(t, !user.CreatedAt.After(after))
assert.Equal(t, user.CreatedAt, user.UpdatedAt)
})
t.Run("generates unique IDs", func(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
user := NewUser("user", "hash", RoleViewer)
assert.False(t, ids[user.ID], "ID should be unique")
ids[user.ID] = true
}
})
t.Run("supports all roles", func(t *testing.T) {
roles := []Role{RoleAdmin, RoleManager, RoleOperator, RoleViewer}
for _, role := range roles {
user := NewUser("user", "hash", role)
assert.Equal(t, role, user.Role)
}
})
t.Run("allows empty username", func(t *testing.T) {
user := NewUser("", "hash", RoleViewer)
assert.Empty(t, user.Username)
assert.NotEmpty(t, user.ID)
})
t.Run("allows empty password hash", func(t *testing.T) {
user := NewUser("user", "", RoleViewer)
assert.Empty(t, user.PasswordHash)
})
}
func TestUser_ToStorage(t *testing.T) {
t.Run("converts all fields", func(t *testing.T) {
now := time.Now().UTC()
user := &User{
ID: "user-123",
Username: "admin",
PasswordHash: "secret-hash",
Role: RoleAdmin,
CreatedAt: now,
UpdatedAt: now.Add(time.Hour),
}
storage := user.ToStorage()
assert.Equal(t, user.ID, storage.ID)
assert.Equal(t, user.Username, storage.Username)
assert.Equal(t, user.PasswordHash, storage.PasswordHash)
assert.Equal(t, user.Role, storage.Role)
assert.Equal(t, user.CreatedAt, storage.CreatedAt)
assert.Equal(t, user.UpdatedAt, storage.UpdatedAt)
})
t.Run("handles empty fields", func(t *testing.T) {
user := &User{}
storage := user.ToStorage()
assert.Empty(t, storage.ID)
assert.Empty(t, storage.Username)
assert.Empty(t, storage.PasswordHash)
})
}
func TestUserForStorage_ToUser(t *testing.T) {
t.Run("converts all fields", func(t *testing.T) {
now := time.Now().UTC()
storage := &UserForStorage{
ID: "storage-456",
Username: "operator",
PasswordHash: "stored-hash",
Role: RoleOperator,
CreatedAt: now,
UpdatedAt: now.Add(2 * time.Hour),
}
user := storage.ToUser()
assert.Equal(t, storage.ID, user.ID)
assert.Equal(t, storage.Username, user.Username)
assert.Equal(t, storage.PasswordHash, user.PasswordHash)
assert.Equal(t, storage.Role, user.Role)
assert.Equal(t, storage.CreatedAt, user.CreatedAt)
assert.Equal(t, storage.UpdatedAt, user.UpdatedAt)
})
t.Run("handles empty fields", func(t *testing.T) {
storage := &UserForStorage{}
user := storage.ToUser()
assert.Empty(t, user.ID)
assert.Empty(t, user.Username)
assert.Empty(t, user.PasswordHash)
})
}
func TestUser_StorageRoundtrip(t *testing.T) {
t.Run("preserves all fields through roundtrip", func(t *testing.T) {
now := time.Now().UTC()
original := &User{
ID: "roundtrip-id",
Username: "roundtrip-user",
PasswordHash: "roundtrip-hash",
Role: RoleManager,
CreatedAt: now,
UpdatedAt: now.Add(time.Minute),
}
storage := original.ToStorage()
recovered := storage.ToUser()
assert.Equal(t, original.ID, recovered.ID)
assert.Equal(t, original.Username, recovered.Username)
assert.Equal(t, original.PasswordHash, recovered.PasswordHash)
assert.Equal(t, original.Role, recovered.Role)
assert.Equal(t, original.CreatedAt, recovered.CreatedAt)
assert.Equal(t, original.UpdatedAt, recovered.UpdatedAt)
})
}
func TestUser_JSONSerialization(t *testing.T) {
t.Run("excludes password hash", func(t *testing.T) {
user := &User{
ID: "json-id",
Username: "jsonuser",
PasswordHash: "should-be-excluded",
Role: RoleViewer,
}
data, err := json.Marshal(user)
require.NoError(t, err)
jsonStr := string(data)
assert.NotContains(t, jsonStr, "should-be-excluded")
assert.NotContains(t, jsonStr, "password_hash")
assert.NotContains(t, jsonStr, "PasswordHash")
})
t.Run("includes other fields", func(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
user := &User{
ID: "json-id",
Username: "jsonuser",
Role: RoleAdmin,
CreatedAt: now,
UpdatedAt: now,
}
data, err := json.Marshal(user)
require.NoError(t, err)
var recovered User
err = json.Unmarshal(data, &recovered)
require.NoError(t, err)
assert.Equal(t, user.ID, recovered.ID)
assert.Equal(t, user.Username, recovered.Username)
assert.Equal(t, user.Role, recovered.Role)
assert.Empty(t, recovered.PasswordHash)
})
}
func TestUserForStorage_JSONSerialization(t *testing.T) {
t.Run("includes password hash", func(t *testing.T) {
storage := &UserForStorage{
ID: "storage-id",
Username: "storageuser",
PasswordHash: "included-hash",
Role: RoleManager,
}
data, err := json.Marshal(storage)
require.NoError(t, err)
jsonStr := string(data)
assert.Contains(t, jsonStr, "included-hash")
assert.Contains(t, jsonStr, "password_hash")
})
t.Run("roundtrip preserves all fields", func(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
original := &UserForStorage{
ID: "storage-rt-id",
Username: "storage-rt-user",
PasswordHash: "storage-rt-hash",
Role: RoleOperator,
CreatedAt: now,
UpdatedAt: now,
}
data, err := json.Marshal(original)
require.NoError(t, err)
var recovered UserForStorage
err = json.Unmarshal(data, &recovered)
require.NoError(t, err)
assert.Equal(t, original.ID, recovered.ID)
assert.Equal(t, original.Username, recovered.Username)
assert.Equal(t, original.PasswordHash, recovered.PasswordHash)
assert.Equal(t, original.Role, recovered.Role)
})
}
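The serialization behavior these tests pin down implies struct tags along the following lines (a hypothetical reconstruction; the real definitions live in internal/auth and the exact JSON key names are assumptions):

```go
// User never serializes its password hash; UserForStorage keeps it for the
// on-disk representation, as the assertions above require.
type User struct {
	ID           string    `json:"id"`
	Username     string    `json:"username"`
	PasswordHash string    `json:"-"` // asserted absent from JSON
	Role         Role      `json:"role"`
	CreatedAt    time.Time `json:"createdAt"`
	UpdatedAt    time.Time `json:"updatedAt"`
}

type UserForStorage struct {
	ID           string    `json:"id"`
	Username     string    `json:"username"`
	PasswordHash string    `json:"password_hash"` // asserted present
	Role         Role      `json:"role"`
	CreatedAt    time.Time `json:"createdAt"`
	UpdatedAt    time.Time `json:"updatedAt"`
}
```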

internal/cmd/cleanup.go Normal file
View File

@ -0,0 +1,134 @@
package cmd
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/spf13/cobra"
)
// Cleanup creates and returns a cobra command for removing old DAG run history.
func Cleanup() *cobra.Command {
return NewCommand(
&cobra.Command{
Use: "cleanup [flags] <DAG name>",
Short: "Remove old DAG run history",
Long: `Remove old DAG run history for a specified DAG.
By default, removes all history except for currently active runs.
Use --retention-days to keep recent history.
Active runs are never deleted for safety.
Examples:
dagu cleanup my-workflow # Delete all history (with confirmation)
dagu cleanup --retention-days 30 my-workflow # Keep last 30 days
dagu cleanup --dry-run my-workflow # Preview what would be deleted
dagu cleanup -y my-workflow # Skip confirmation
`,
Args: cobra.ExactArgs(1),
},
cleanupFlags,
runCleanup,
)
}
var cleanupFlags = []commandLineFlag{
retentionDaysFlag,
dryRunFlag,
yesFlag,
}
func runCleanup(ctx *Context, args []string) error {
dagName := args[0]
// Parse retention days (flags are string-based in this codebase)
retentionStr, err := ctx.StringParam("retention-days")
if err != nil {
return fmt.Errorf("failed to get retention-days: %w", err)
}
retentionDays, err := strconv.Atoi(retentionStr)
if err != nil {
return fmt.Errorf("invalid retention-days value %q: must be a non-negative integer", retentionStr)
}
// Reject negative retention (clearer error than silent no-op)
if retentionDays < 0 {
return fmt.Errorf("retention-days cannot be negative (got %d)", retentionDays)
}
// Get boolean flags
dryRun, _ := ctx.Command.Flags().GetBool("dry-run")
skipConfirm, _ := ctx.Command.Flags().GetBool("yes")
// Build description message
var actionDesc string
if retentionDays == 0 {
actionDesc = fmt.Sprintf("all history for DAG %q", dagName)
} else {
actionDesc = fmt.Sprintf("history older than %d days for DAG %q", retentionDays, dagName)
}
// Build options for RemoveOldDAGRuns
var opts []execution.RemoveOldDAGRunsOption
if dryRun {
opts = append(opts, execution.WithDryRun())
}
// Dry run mode - show what would be deleted
if dryRun {
runIDs, err := ctx.DAGRunStore.RemoveOldDAGRuns(ctx, dagName, retentionDays, opts...)
if err != nil {
return fmt.Errorf("failed to check history for %q: %w", dagName, err)
}
if len(runIDs) == 0 {
fmt.Printf("Dry run: No runs to delete for DAG %q\n", dagName)
} else {
fmt.Printf("Dry run: Would delete %d run(s) for DAG %q:\n", len(runIDs), dagName)
for _, runID := range runIDs {
fmt.Printf(" - %s\n", runID)
}
}
return nil
}
// Confirmation prompt (unless --yes or --quiet)
if !skipConfirm && !ctx.Quiet {
fmt.Printf("This will delete %s.\n", actionDesc)
fmt.Println("Active runs will be preserved.")
fmt.Print("Continue? [y/N]: ")
reader := bufio.NewReader(os.Stdin)
response, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read user input: %w", err)
}
response = strings.TrimSpace(strings.ToLower(response))
if response != "y" && response != "yes" {
fmt.Println("Cancelled.")
return nil
}
}
// Execute cleanup using the existing DAGRunStore method
runIDs, err := ctx.DAGRunStore.RemoveOldDAGRuns(ctx, dagName, retentionDays, opts...)
if err != nil {
return fmt.Errorf("failed to cleanup history for %q: %w", dagName, err)
}
if !ctx.Quiet {
if len(runIDs) == 0 {
fmt.Printf("No runs to delete for DAG %q\n", dagName)
} else {
fmt.Printf("Successfully removed %d run(s) for DAG %q\n", len(runIDs), dagName)
}
}
return nil
}
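runCleanup passes variadic options to RemoveOldDAGRuns. Their definitions are outside this diff, but a conventional functional-options shape would be (a sketch, not the actual internal/core/execution code):

```go
// Hypothetical option plumbing: WithDryRun flips a flag the store checks
// before deleting anything, so the same call returns the matching run IDs
// in both modes.
type removeOldDAGRunsOptions struct {
	dryRun bool
}

type RemoveOldDAGRunsOption func(*removeOldDAGRunsOptions)

func WithDryRun() RemoveOldDAGRunsOption {
	return func(o *removeOldDAGRunsOptions) { o.dryRun = true }
}
```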

View File

@ -0,0 +1,280 @@
package cmd_test
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/dagu-org/dagu/internal/cmd"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCleanupCommand(t *testing.T) {
t.Run("DeletesAllHistoryWithRetentionZero", func(t *testing.T) {
th := test.SetupCommand(t)
// Create a DAG and run it to generate history
dag := th.DAG(t, `steps:
- name: "1"
command: echo "hello"
`)
// Run the DAG to create history
th.RunCommand(t, cmd.Start(), test.CmdTest{
Args: []string{"start", dag.Location},
})
// Wait for DAG to complete
dag.AssertLatestStatus(t, core.Succeeded)
// Verify history exists
dag.AssertDAGRunCount(t, 1)
// Run cleanup with --yes to skip confirmation
th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--yes", dag.Name},
})
// Verify history is deleted
dag.AssertDAGRunCount(t, 0)
})
t.Run("PreservesRecentHistoryWithRetentionDays", func(t *testing.T) {
th := test.SetupCommand(t)
// Create a DAG and run it
dag := th.DAG(t, `steps:
- name: "1"
command: echo "hello"
`)
// Run the DAG to create history
th.RunCommand(t, cmd.Start(), test.CmdTest{
Args: []string{"start", dag.Location},
})
// Wait for DAG to complete
dag.AssertLatestStatus(t, core.Succeeded)
// Verify history exists
dag.AssertDAGRunCount(t, 1)
// Run cleanup with retention of 30 days (should keep recent history)
th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--retention-days", "30", "--yes", dag.Name},
})
// Verify history is still there (it's less than 30 days old)
dag.AssertDAGRunCount(t, 1)
})
t.Run("DryRunDoesNotDelete", func(t *testing.T) {
th := test.SetupCommand(t)
// Create a DAG and run it
dag := th.DAG(t, `steps:
- name: "1"
command: echo "hello"
`)
// Run the DAG to create history
th.RunCommand(t, cmd.Start(), test.CmdTest{
Args: []string{"start", dag.Location},
})
// Wait for DAG to complete
dag.AssertLatestStatus(t, core.Succeeded)
// Verify history exists
dag.AssertDAGRunCount(t, 1)
// Run cleanup with --dry-run
th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--dry-run", dag.Name},
})
// Verify history is still there (dry run should not delete)
dag.AssertDAGRunCount(t, 1)
})
t.Run("PreservesActiveRuns", func(t *testing.T) {
th := test.SetupCommand(t)
// Create a DAG that runs for a while
dag := th.DAG(t, `steps:
- name: "1"
command: sleep 30
`)
done := make(chan struct{})
go func() {
// Start the DAG
th.RunCommand(t, cmd.Start(), test.CmdTest{
Args: []string{"start", dag.Location},
})
close(done)
}()
// Wait for DAG to start running
time.Sleep(time.Millisecond * 200)
dag.AssertLatestStatus(t, core.Running)
// Try to clean up while running (nothing to delete since only the active run exists)
th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--yes", dag.Name},
})
// Verify the running DAG is still there (should be preserved)
dag.AssertLatestStatus(t, core.Running)
// Stop the DAG
th.RunCommand(t, cmd.Stop(), test.CmdTest{
Args: []string{"stop", dag.Location},
})
<-done
})
t.Run("RejectsNegativeRetentionDays", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `steps:
- name: "1"
command: echo "hello"
`)
err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--retention-days", "-1", dag.Name},
})
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot be negative")
})
t.Run("RejectsInvalidRetentionDays", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `steps:
- name: "1"
command: echo "hello"
`)
err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--retention-days", "abc", dag.Name},
})
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid retention-days")
})
t.Run("RequiresDAGNameArgument", func(t *testing.T) {
th := test.SetupCommand(t)
err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--yes"},
})
require.Error(t, err)
assert.Contains(t, err.Error(), "accepts 1 arg")
})
t.Run("SucceedsForNonExistentDAG", func(t *testing.T) {
th := test.SetupCommand(t)
// Cleanup for a DAG that doesn't exist should succeed silently
th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
Args: []string{"cleanup", "--yes", "non-existent-dag"},
})
})
}
func TestCleanupCommandDirectStore(t *testing.T) {
// Test cleanup using the DAGRunStore directly to verify underlying behavior
t.Run("RemoveOldDAGRunsWithStore", func(t *testing.T) {
th := test.Setup(t)
dagName := "test-cleanup-dag"
// Create old DAG runs directly in the store
oldTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
recentTime := time.Now()
// Create a minimal DAG for the test
testDAG := &core.DAG{Name: dagName}
// Create an old run
oldAttempt, err := th.DAGRunStore.CreateAttempt(
th.Context,
testDAG,
oldTime,
"old-run-id",
execution.NewDAGRunAttemptOptions{},
)
require.NoError(t, err)
require.NoError(t, oldAttempt.Open(th.Context))
require.NoError(t, oldAttempt.Write(th.Context, execution.DAGRunStatus{
Name: dagName,
DAGRunID: "old-run-id",
Status: core.Succeeded,
}))
require.NoError(t, oldAttempt.Close(th.Context))
// Create a recent run
recentAttempt, err := th.DAGRunStore.CreateAttempt(
th.Context,
testDAG,
recentTime,
"recent-run-id",
execution.NewDAGRunAttemptOptions{},
)
require.NoError(t, err)
require.NoError(t, recentAttempt.Open(th.Context))
require.NoError(t, recentAttempt.Write(th.Context, execution.DAGRunStatus{
Name: dagName,
DAGRunID: "recent-run-id",
Status: core.Succeeded,
}))
require.NoError(t, recentAttempt.Close(th.Context))
// Manually set old file modification time
setOldModTime(t, th.Config.Paths.DAGRunsDir, dagName, "", oldTime)
// Verify both runs exist
runs := th.DAGRunStore.RecentAttempts(th.Context, dagName, 10)
require.Len(t, runs, 2)
// Remove runs older than 7 days
removedIDs, err := th.DAGRunStore.RemoveOldDAGRuns(th.Context, dagName, 7)
require.NoError(t, err)
assert.Len(t, removedIDs, 1)
// Verify old run is deleted, recent run remains
runs = th.DAGRunStore.RecentAttempts(th.Context, dagName, 10)
require.Len(t, runs, 1)
status, err := runs[0].ReadStatus(th.Context)
require.NoError(t, err)
assert.Equal(t, "recent-run-id", status.DAGRunID)
})
}
// setOldModTime sets old modification time on DAG run files
func setOldModTime(t *testing.T, baseDir, dagName, _ string, modTime time.Time) {
t.Helper()
// Find the run directory
dagRunsDir := filepath.Join(baseDir, dagName, "dag-runs")
err := filepath.Walk(dagRunsDir, func(path string, _ os.FileInfo, err error) error {
if err != nil {
return err
}
// Set mod time on all files and directories
return os.Chtimes(path, modTime, modTime)
})
// Ignore errors if directory doesn't exist
if err != nil && !os.IsNotExist(err) {
t.Logf("Warning: failed to set mod time: %v", err)
}
}

View File

@ -2,6 +2,7 @@ package cmd
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
@ -27,6 +28,7 @@ import (
"github.com/dagu-org/dagu/internal/persistence/filequeue"
"github.com/dagu-org/dagu/internal/persistence/fileserviceregistry"
"github.com/dagu-org/dagu/internal/runtime"
"github.com/dagu-org/dagu/internal/runtime/transform"
"github.com/dagu-org/dagu/internal/service/coordinator"
"github.com/dagu-org/dagu/internal/service/frontend"
"github.com/dagu-org/dagu/internal/service/resource"
@ -355,9 +357,7 @@ func (c *Context) GenLogFileName(dag *core.DAG, dagRunID string) (string, error)
// NewCommand creates a new command instance with the given cobra command and run function.
func NewCommand(cmd *cobra.Command, flags []commandLineFlag, runFunc func(cmd *Context, args []string) error) *cobra.Command {
config.WithViperLock(func() {
initFlags(cmd, flags...)
})
initFlags(cmd, flags...)
cmd.SilenceUsage = true
@ -490,6 +490,53 @@ func (cfg LogConfig) LogDir() (string, error) {
return logDir, nil
}
// RecordEarlyFailure records a failure in the execution history before the DAG has fully started.
// This is used for infrastructure errors like singleton conflicts or process acquisition failures.
func (c *Context) RecordEarlyFailure(dag *core.DAG, dagRunID string, err error) error {
if dag == nil || dagRunID == "" {
return fmt.Errorf("DAG and dag-run ID are required to record failure")
}
// 1. Check if a DAGRunAttempt already exists for the given run-id.
ref := execution.NewDAGRunRef(dag.Name, dagRunID)
attempt, findErr := c.DAGRunStore.FindAttempt(c, ref)
if findErr != nil && !errors.Is(findErr, execution.ErrDAGRunIDNotFound) {
return fmt.Errorf("failed to check for existing attempt: %w", findErr)
}
if attempt == nil {
// 2. Create the attempt if not exists
att, createErr := c.DAGRunStore.CreateAttempt(c, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
if createErr != nil {
return fmt.Errorf("failed to create run to record failure: %w", createErr)
}
attempt = att
}
// 3. Construct the "Failed" status
statusBuilder := transform.NewStatusBuilder(dag)
logPath, _ := c.GenLogFileName(dag, dagRunID)
status := statusBuilder.Create(dagRunID, core.Failed, 0, time.Now(),
transform.WithLogFilePath(logPath),
transform.WithFinishedAt(time.Now()),
transform.WithError(err.Error()),
)
// 4. Write the status
if err := attempt.Open(c); err != nil {
return fmt.Errorf("failed to open attempt for recording failure: %w", err)
}
defer func() {
_ = attempt.Close(c)
}()
if err := attempt.Write(c, status); err != nil {
return fmt.Errorf("failed to write failed status: %w", err)
}
return nil
}
// LogFile constructs the log filename using the prefix, safe DAG name, current timestamp,
// and a truncated version of the dag-run ID.
func (cfg LogConfig) LogFile() string {

View File

@ -167,6 +167,8 @@ func TestModel(t *testing.T) {
})
}
var _ execution.DAGStore = (*mockDAGStore)(nil)
// mockDAGStore is a mock implementation of models.DAGStore
type mockDAGStore struct {
mock.Mock

View File

@ -65,21 +65,6 @@ func runEnqueue(ctx *Context, args []string) error {
dag.Queue = queueOverride
}
// Check queued DAG-runs
queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
if err != nil {
return fmt.Errorf("failed to read queue: %w", err)
}
// If the DAG has a queue configured and maxActiveRuns > 1, ensure the number
// of active runs in the queue does not exceed this limit.
// No need to check when maxActiveRuns <= 1: for enqueueing, the queue-level
// maxConcurrency is then the only cap.
if dag.Queue != "" && dag.MaxActiveRuns > 1 && len(queuedRuns) >= dag.MaxActiveRuns {
// The same DAG is already in the queue
return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot enqueue", dag.Name, dag.MaxActiveRuns)
}
return enqueueDAGRun(ctx, dag, runID)
}

View File

@ -1,9 +1,7 @@
package cmd
import (
"errors"
"fmt"
"log/slog"
"os"
"strings"
@ -21,7 +19,6 @@ const (
flagWorkdir = "workdir"
flagShell = "shell"
flagBase = "base"
flagSingleton = "singleton"
defaultStepName = "main"
execCommandUsage = "exec [flags] -- <command> [args...]"
)
@ -30,12 +27,9 @@ var (
execFlags = []commandLineFlag{
dagRunIDFlag,
nameFlag,
queueFlag,
noQueueFlag,
workdirFlag,
shellFlag,
baseFlag,
singletonFlag,
}
)
@ -49,7 +43,7 @@ func Exec() *cobra.Command {
Examples:
dagu exec -- echo "hello world"
dagu exec --env FOO=bar -- sh -c 'echo $FOO'
dagu exec --queue nightly --worker-label role=batch -- python nightly.py`,
dagu exec --worker-label role=batch -- python remote_script.py`,
Args: cobra.ArbitraryArgs,
}
@ -62,14 +56,12 @@ Examples:
return command
}
// runExec parses flags and arguments and executes the provided command as an inline DAG run,
// either enqueueing it for distributed execution or running it immediately in-process.
// It validates inputs (run-id, working directory, base and dotenv files, env vars, worker labels,
// queue/singleton flags), builds the DAG for the inline command, and chooses between enqueueing
// (when queues/worker labels require it or when max runs are reached) or direct execution.
// runExec parses flags and arguments and executes the provided command as an inline DAG run.
// It validates inputs (run-id, working directory, base and dotenv files, env vars, worker labels),
// builds the DAG for the inline command, and executes it locally.
// ctx provides CLI and application context; args are the command and its arguments.
// Returns an error for validation failures, when a dag-run with the same run-id already exists,
// or if enqueueing/execution fails.
// or if execution fails.
func runExec(ctx *Context, args []string) error {
if len(args) == 0 {
return fmt.Errorf("command is required (try: dagu exec -- <command>)")
@ -177,29 +169,10 @@ func runExec(ctx *Context, args []string) error {
workerLabels[key] = value
}
queueName, err := ctx.Command.Flags().GetString("queue")
if err != nil {
return fmt.Errorf("failed to read queue flag: %w", err)
}
noQueue, err := ctx.Command.Flags().GetBool("no-queue")
if err != nil {
return fmt.Errorf("failed to read no-queue flag: %w", err)
}
singleton, err := ctx.Command.Flags().GetBool(flagSingleton)
if err != nil {
return fmt.Errorf("failed to read singleton flag: %w", err)
}
queueDisabled := !ctx.Config.Queues.Enabled || noQueue
if len(workerLabels) > 0 {
if !ctx.Config.Queues.Enabled {
return fmt.Errorf("worker selector requires queues; enable queues or remove --worker-label")
}
if noQueue {
return fmt.Errorf("--worker-label cannot be combined with --no-queue")
}
queueDisabled = false
}
opts := ExecOptions{
@ -210,9 +183,6 @@ func runExec(ctx *Context, args []string) error {
Env: envVars,
DotenvFiles: dotenvPaths,
BaseConfig: baseConfig,
Queue: queueName,
NoQueue: noQueue,
Singleton: singleton,
WorkerLabels: workerLabels,
}
@ -223,28 +193,6 @@ func runExec(ctx *Context, args []string) error {
dagRunRef := execution.NewDAGRunRef(dag.Name, runID)
if !queueDisabled && len(workerLabels) > 0 {
logger.Info(ctx, "Queueing inline dag-run for distributed execution",
tag.DAG(dag.Name),
tag.RunID(runID),
slog.Any("worker-selector", workerLabels),
tag.Command(strings.Join(args, " ")),
)
dag.Location = ""
return enqueueDAGRun(ctx, dag, runID)
}
if !queueDisabled && dag.Queue != "" {
logger.Info(ctx, "Queueing inline dag-run",
tag.DAG(dag.Name),
tag.Queue(dag.Queue),
tag.RunID(runID),
tag.Command(strings.Join(args, " ")),
)
dag.Location = ""
return enqueueDAGRun(ctx, dag, runID)
}
attempt, _ := ctx.DAGRunStore.FindAttempt(ctx, dagRunRef)
if attempt != nil {
return fmt.Errorf("dag-run ID %s already exists for DAG %s", runID, dag.Name)
@ -256,16 +204,7 @@ func runExec(ctx *Context, args []string) error {
tag.RunID(runID),
)
err = tryExecuteDAG(ctx, dag, runID, dagRunRef, false)
if errors.Is(err, errMaxRunReached) && !queueDisabled {
logger.Info(ctx, "Max active runs reached; enqueueing dag-run instead",
tag.DAG(dag.Name),
tag.RunID(runID),
)
dag.Location = ""
return enqueueDAGRun(ctx, dag, runID)
}
return err
return tryExecuteDAG(ctx, dag, runID, dagRunRef)
}
var (
@ -281,9 +220,4 @@ var (
name: flagBase,
usage: "Path to a base DAG YAML whose defaults are applied before inline overrides",
}
singletonFlag = commandLineFlag{
name: flagSingleton,
usage: "Limit execution to a single active run (sets maxActiveRuns=1)",
isBool: true,
}
)

View File

@ -6,6 +6,7 @@ import (
"path/filepath"
"strings"
"github.com/dagu-org/dagu/internal/common/cmdutil"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/dagu-org/dagu/internal/common/stringutil"
"github.com/dagu-org/dagu/internal/core"
@ -22,9 +23,6 @@ type ExecOptions struct {
Env []string
DotenvFiles []string
BaseConfig string
Queue string
NoQueue bool
Singleton bool
WorkerLabels map[string]string
}
@ -34,16 +32,14 @@ type execSpec struct {
WorkingDir string `yaml:"workingDir,omitempty"`
Env []string `yaml:"env,omitempty"`
Dotenv []string `yaml:"dotenv,omitempty"`
MaxActiveRuns int `yaml:"maxActiveRuns,omitempty"`
Queue string `yaml:"queue,omitempty"`
WorkerSelector map[string]string `yaml:"workerSelector,omitempty"`
Steps []execStep `yaml:"steps"`
}
type execStep struct {
Name string `yaml:"name"`
Command []string `yaml:"command,omitempty"`
Shell string `yaml:"shell,omitempty"`
Name string `yaml:"name"`
Command string `yaml:"command,omitempty"`
Shell string `yaml:"shell,omitempty"`
}
func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
@ -59,15 +55,8 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
return nil, "", fmt.Errorf("invalid DAG name: %w", err)
}
maxActiveRuns := -1
if opts.Singleton {
maxActiveRuns = 1
}
queueValue := ""
if opts.Queue != "" && !opts.NoQueue {
queueValue = opts.Queue
}
// Build command string from args
commandStr := cmdutil.BuildCommandEscapedString(opts.CommandArgs[0], opts.CommandArgs[1:])
specDoc := execSpec{
Name: name,
@ -75,13 +64,11 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
WorkingDir: opts.WorkingDir,
Env: opts.Env,
Dotenv: opts.DotenvFiles,
MaxActiveRuns: maxActiveRuns,
Queue: queueValue,
WorkerSelector: opts.WorkerLabels,
Steps: []execStep{
{
Name: defaultStepName,
Command: opts.CommandArgs,
Command: commandStr,
Shell: opts.ShellOverride,
},
},
@ -127,19 +114,10 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
dag.Name = name
dag.WorkingDir = opts.WorkingDir
if opts.Queue != "" && !opts.NoQueue {
dag.Queue = opts.Queue
} else if opts.NoQueue {
dag.Queue = ""
}
if len(opts.WorkerLabels) > 0 {
dag.WorkerSelector = opts.WorkerLabels
}
if opts.Singleton {
dag.MaxActiveRuns = 1
} else {
dag.MaxActiveRuns = -1
}
dag.MaxActiveRuns = -1
dag.Location = ""
return dag, string(specYAML), nil
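With the step's Command field now a single string, the exec arguments are joined into one shell-ready command before being written into the inline spec. Roughly (a hypothetical illustration; the exact quoting is whatever cmdutil.BuildCommandEscapedString implements):

```go
// exampleJoin sketches the argv-to-string conversion used in buildExecDAG.
func exampleJoin() string {
	args := []string{"sh", "-c", "echo $FOO"}
	// Produces one escaped command string instead of an argv list,
	// e.g. something like: sh -c "echo $FOO"
	return cmdutil.BuildCommandEscapedString(args[0], args[1:])
}
```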

View File

@ -1,7 +1,6 @@
package cmd
import (
"github.com/dagu-org/dagu/internal/common/config"
"github.com/dagu-org/dagu/internal/common/stringutil"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@ -68,14 +67,6 @@ var (
usage: "Override the DAG name (default: name from DAG definition or filename)",
}
// noQueueFlag is used to indicate that the dag-run should not be queued and should be executed immediately.
noQueueFlag = commandLineFlag{
name: "no-queue",
usage: "Do not queue the dag-run, execute immediately",
isBool: true,
shorthand: "n",
}
// Unique dag-run ID required for retrying a dag-run.
// This flag must be provided when using the retry command.
dagRunIDFlagRetry = commandLineFlag{
@ -93,15 +84,6 @@ var (
defaultValue: "",
}
// noCheckMaxActiveRuns
disableMaxActiveRuns = commandLineFlag{
name: "disable-max-active-runs",
shorthand: "",
usage: "Disable check for max active runs",
isBool: true,
defaultValue: "",
}
// Unique dag-run ID used for starting a new dag-run.
// This is used to track and identify the execution instance and its status.
dagRunIDFlag = commandLineFlag{
@ -267,6 +249,30 @@ var (
isBool: true,
bindViper: true,
}
// retentionDaysFlag specifies the number of days to retain history.
// Records older than this will be deleted.
// If set to 0, all records (except active) will be deleted.
retentionDaysFlag = commandLineFlag{
name: "retention-days",
defaultValue: "0",
usage: "Number of days to retain history (0 = delete all, except active runs)",
}
// dryRunFlag enables preview mode without actual deletion.
dryRunFlag = commandLineFlag{
name: "dry-run",
usage: "Preview what would be deleted without actually deleting",
isBool: true,
}
// yesFlag skips the confirmation prompt.
yesFlag = commandLineFlag{
name: "yes",
shorthand: "y",
usage: "Skip confirmation prompt",
isBool: true,
}
)
type commandLineFlag struct {
@ -296,16 +302,13 @@ func initFlags(cmd *cobra.Command, additionalFlags ...commandLineFlag) {
// bindFlags binds command-line flags to the provided Viper instance for configuration lookup.
// It binds only flags whose `bindViper` field is true, using the camel-cased key produced
// from each flag's kebab-case name. Binding is performed while holding the config package's
// Viper lock to ensure thread-safe registration.
// from each flag's kebab-case name.
func bindFlags(viper *viper.Viper, cmd *cobra.Command, additionalFlags ...commandLineFlag) {
flags := append([]commandLineFlag{configFlag}, additionalFlags...)
config.WithViperLock(func() {
for _, flag := range flags {
if flag.bindViper {
_ = viper.BindPFlag(stringutil.KebabToCamel(flag.name), cmd.Flags().Lookup(flag.name))
}
for _, flag := range flags {
if flag.bindViper {
_ = viper.BindPFlag(stringutil.KebabToCamel(flag.name), cmd.Flags().Lookup(flag.name))
}
})
}
}
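The key transformation described in the doc comment, illustrated (assuming stringutil.KebabToCamel behaves as that comment states):

```go
// Flags are registered kebab-case on the CLI but bound camelCase in Viper:
//
//	stringutil.KebabToCamel("retention-days") // "retentionDays"
//	stringutil.KebabToCamel("dry-run")        // "dryRun"
//	stringutil.KebabToCamel("config")         // "config" (no hyphen, unchanged)
```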

View File

@ -17,6 +17,8 @@ import (
"github.com/stretchr/testify/require"
)
var _ execution.DAGStore = (*mockDAGStore)(nil)
// mockDAGStore implements models.DAGStore for testing
type mockDAGStore struct {
dags map[string]*core.DAG

View File

@ -0,0 +1,182 @@
package cmd_test
import (
"errors"
"testing"
"time"
"github.com/dagu-org/dagu/internal/cmd"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/require"
)
func TestRecordEarlyFailure(t *testing.T) {
t.Run("RecordsFailureForNewDAGRun", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `
steps:
- name: step1
command: echo hello
`)
dagRunID := "test-run-id-001"
testErr := errors.New("process acquisition failed")
// Create Context with required stores
ctx := &cmd.Context{
Context: th.Context,
Config: th.Config,
DAGRunStore: th.DAGRunStore,
}
// Record the early failure
err := ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
require.NoError(t, err)
// Verify the failure was recorded
ref := execution.NewDAGRunRef(dag.Name, dagRunID)
attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
require.NoError(t, err)
require.NotNil(t, attempt)
status, err := attempt.ReadStatus(th.Context)
require.NoError(t, err)
require.Equal(t, core.Failed, status.Status)
require.Contains(t, status.Error, "process acquisition failed")
})
t.Run("RecordsFailureForExistingAttempt", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `
steps:
- name: step1
command: echo hello
`)
// First, run the DAG to create an attempt
th.RunCommand(t, cmd.Start(), test.CmdTest{
Args: []string{"start", dag.Location},
})
// Get the existing run ID
latestStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
require.NoError(t, err)
dagRunID := latestStatus.DAGRunID
// Now record an early failure for the same run ID
testErr := errors.New("retry failed due to lock contention")
ctx := &cmd.Context{
Context: th.Context,
Config: th.Config,
DAGRunStore: th.DAGRunStore,
}
err = ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
require.NoError(t, err)
// Verify the failure was recorded (status should be updated)
ref := execution.NewDAGRunRef(dag.Name, dagRunID)
attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
require.NoError(t, err)
require.NotNil(t, attempt)
status, err := attempt.ReadStatus(th.Context)
require.NoError(t, err)
require.Equal(t, core.Failed, status.Status)
require.Contains(t, status.Error, "retry failed due to lock contention")
})
t.Run("ReturnsErrorForNilDAG", func(t *testing.T) {
th := test.SetupCommand(t)
ctx := &cmd.Context{
Context: th.Context,
Config: th.Config,
DAGRunStore: th.DAGRunStore,
}
err := ctx.RecordEarlyFailure(nil, "some-run-id", errors.New("test error"))
require.Error(t, err)
require.Contains(t, err.Error(), "DAG and dag-run ID are required")
})
t.Run("ReturnsErrorForEmptyDAGRunID", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `
steps:
- name: step1
command: echo hello
`)
ctx := &cmd.Context{
Context: th.Context,
Config: th.Config,
DAGRunStore: th.DAGRunStore,
}
err := ctx.RecordEarlyFailure(dag.DAG, "", errors.New("test error"))
require.Error(t, err)
require.Contains(t, err.Error(), "DAG and dag-run ID are required")
})
t.Run("CanRetryEarlyFailureRecord", func(t *testing.T) {
th := test.SetupCommand(t)
dag := th.DAG(t, `
steps:
- name: step1
command: echo hello
`)
dagRunID := "early-failure-retry-test"
testErr := errors.New("initial process acquisition failed")
// Create Context and record early failure
ctx := &cmd.Context{
Context: th.Context,
Config: th.Config,
DAGRunStore: th.DAGRunStore,
}
err := ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
require.NoError(t, err)
// Verify initial failure status
ref := execution.NewDAGRunRef(dag.Name, dagRunID)
attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
require.NoError(t, err)
require.NotNil(t, attempt)
status, err := attempt.ReadStatus(th.Context)
require.NoError(t, err)
require.Equal(t, core.Failed, status.Status)
// Verify DAG can be read back (required for retry)
storedDAG, err := attempt.ReadDAG(th.Context)
require.NoError(t, err)
require.NotNil(t, storedDAG)
require.Equal(t, dag.Name, storedDAG.Name)
// Now retry the early failure record
th.RunCommand(t, cmd.Retry(), test.CmdTest{
Args: []string{"retry", "--run-id", dagRunID, dag.Name},
})
// Wait for retry to complete
require.Eventually(t, func() bool {
currentStatus, err := th.DAGRunMgr.GetCurrentStatus(th.Context, dag.DAG, dagRunID)
return err == nil && currentStatus != nil && currentStatus.Status == core.Succeeded
}, 5*time.Second, 100*time.Millisecond, "Retry should succeed")
// Verify final status is succeeded
finalStatus, err := th.DAGRunMgr.GetCurrentStatus(th.Context, dag.DAG, dagRunID)
require.NoError(t, err)
require.Equal(t, core.Succeeded, finalStatus.Status)
})
}

View File

@ -87,9 +87,9 @@ func runRestart(ctx *Context, args []string) error {
return nil
}
func handleRestartProcess(ctx *Context, d *core.DAG, dagRunID string) error {
func handleRestartProcess(ctx *Context, d *core.DAG, oldDagRunID string) error {
// Stop if running
if err := stopDAGIfRunning(ctx, ctx.DAGRunMgr, d, dagRunID); err != nil {
if err := stopDAGIfRunning(ctx, ctx.DAGRunMgr, d, oldDagRunID); err != nil {
return err
}
@ -99,36 +99,40 @@ func handleRestartProcess(ctx *Context, d *core.DAG, dagRunID string) error {
time.Sleep(d.RestartWait)
}
// Generate new dag-run ID for the restart
newDagRunID, err := genRunID()
if err != nil {
return fmt.Errorf("failed to generate dag-run ID: %w", err)
}
// Execute the exact same DAG with the same parameters but a new dag-run ID
if err := ctx.ProcStore.Lock(ctx, d.ProcGroup()); err != nil {
logger.Debug(ctx, "Failed to lock process group", tag.Error(err))
return errMaxRunReached
_ = ctx.RecordEarlyFailure(d, newDagRunID, err)
return errProcAcquisitionFailed
}
defer ctx.ProcStore.Unlock(ctx, d.ProcGroup())
// Acquire process handle
proc, err := ctx.ProcStore.Acquire(ctx, d.ProcGroup(), execution.NewDAGRunRef(d.Name, dagRunID))
proc, err := ctx.ProcStore.Acquire(ctx, d.ProcGroup(), execution.NewDAGRunRef(d.Name, newDagRunID))
if err != nil {
ctx.ProcStore.Unlock(ctx, d.ProcGroup())
logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
_ = ctx.RecordEarlyFailure(d, newDagRunID, err)
return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
}
defer func() {
_ = proc.Stop(ctx)
}()
// Unlock the process group
// Unlock the process group immediately after acquiring the handle
ctx.ProcStore.Unlock(ctx, d.ProcGroup())
return executeDAG(ctx, ctx.DAGRunMgr, d)
return executeDAGWithRunID(ctx, ctx.DAGRunMgr, d, newDagRunID)
}
// It returns an error if run ID generation, log or DAG store initialization, or agent execution fails.
func executeDAG(ctx *Context, cli runtime.Manager, dag *core.DAG) error {
dagRunID, err := genRunID()
if err != nil {
return fmt.Errorf("failed to generate dag-run ID: %w", err)
}
// executeDAGWithRunID executes a DAG with a pre-generated run ID.
// It returns an error if log or DAG store initialization, or agent execution fails.
func executeDAGWithRunID(ctx *Context, cli runtime.Manager, dag *core.DAG, dagRunID string) error {
logFile, err := ctx.OpenLogFile(dag, dagRunID)
if err != nil {
return fmt.Errorf("failed to initialize log file: %w", err)

View File

@ -2,18 +2,14 @@ package cmd
import (
"fmt"
"log/slog"
"path/filepath"
"time"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/dagu-org/dagu/internal/common/logger"
"github.com/dagu-org/dagu/internal/common/logger/tag"
"github.com/dagu-org/dagu/internal/common/stringutil"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/runtime/agent"
"github.com/dagu-org/dagu/internal/runtime/transform"
"github.com/spf13/cobra"
)
@ -37,12 +33,12 @@ Examples:
)
}
var retryFlags = []commandLineFlag{dagRunIDFlagRetry, stepNameForRetry, disableMaxActiveRuns, noQueueFlag}
var retryFlags = []commandLineFlag{dagRunIDFlagRetry, stepNameForRetry}
func runRetry(ctx *Context, args []string) error {
// Extract retry details
dagRunID, _ := ctx.StringParam("run-id")
stepName, _ := ctx.StringParam("step")
disableMaxActiveRuns := ctx.Command.Flags().Changed("disable-max-active-runs")
name, err := extractDAGName(ctx, args[0])
if err != nil {
@ -71,56 +67,18 @@ func runRetry(ctx *Context, args []string) error {
// Set DAG context for all logs
ctx.Context = logger.WithValues(ctx.Context, tag.DAG(dag.Name), tag.RunID(dagRunID))
// Check if queue is disabled via config or flag
queueDisabled := !ctx.Config.Queues.Enabled || ctx.Command.Flags().Changed("no-queue")
// Check if this DAG should be distributed to workers
// If the DAG has a workerSelector and the queue is not disabled,
// enqueue it so the scheduler can dispatch it to a worker.
// The --no-queue flag acts as a circuit breaker to prevent infinite loops
// when the worker executes the dispatched retry task.
if !queueDisabled && len(dag.WorkerSelector) > 0 {
logger.Info(ctx, "DAG has workerSelector, enqueueing retry for distributed execution", slog.Any("worker-selector", dag.WorkerSelector))
// Enqueue the retry - must create new attempt with status "Queued"
// so the scheduler will process it
if err := enqueueRetry(ctx, dag, dagRunID); err != nil {
return fmt.Errorf("failed to enqueue retry: %w", err)
}
logger.Info(ctx.Context, "Retry enqueued")
return nil
}
// Try lock proc store to avoid race
if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
return fmt.Errorf("failed to lock process group: %w", err)
}
defer ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
if !disableMaxActiveRuns {
liveCount, err := ctx.ProcStore.CountAliveByDAGName(ctx, dag.ProcGroup(), dag.Name)
if err != nil {
return fmt.Errorf("failed to access proc store: %w", err)
}
// Count queued DAG-runs and return an error if a new run plus the
// active runs would exceed maxActiveRuns.
queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
if err != nil {
return fmt.Errorf("failed to read queue: %w", err)
}
// If the DAG has a queue configured and maxActiveRuns > 0, ensure the number
// of active runs in the queue does not exceed this limit.
if dag.MaxActiveRuns > 0 && len(queuedRuns)+liveCount >= dag.MaxActiveRuns {
return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot start", dag.Name, dag.MaxActiveRuns)
}
}
// Acquire process handle
proc, err := ctx.ProcStore.Acquire(ctx, dag.ProcGroup(), execution.NewDAGRunRef(dag.Name, dagRunID))
if err != nil {
ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
}
defer func() {
_ = proc.Stop(ctx)
@ -189,65 +147,3 @@ func executeRetry(ctx *Context, dag *core.DAG, status *execution.DAGRunStatus, r
// Use the shared agent execution function
return ExecuteAgent(ctx, agentInstance, dag, status.DAGRunID, logFile)
}
// enqueueRetry creates a new attempt for retry and enqueues it for execution
func enqueueRetry(ctx *Context, dag *core.DAG, dagRunID string) error {
// Queued dag-runs must not have a location, because the location is used to generate a
// unix pipe. If two DAGs have the same location, they cannot run at the same time.
// Queued DAGs can be run at the same time depending on the `maxActiveRuns` setting.
dag.Location = ""
// Check if queues are enabled
if !ctx.Config.Queues.Enabled {
return fmt.Errorf("queues are disabled in configuration")
}
// Create a new attempt for retry
att, err := ctx.DAGRunStore.CreateAttempt(ctx.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{
Retry: true,
})
if err != nil {
return fmt.Errorf("failed to create retry attempt: %w", err)
}
// Generate log file name
logFile, err := ctx.GenLogFileName(dag, dagRunID)
if err != nil {
return fmt.Errorf("failed to generate log file name: %w", err)
}
// Create status for the new attempt with "Queued" status
opts := []transform.StatusOption{
transform.WithLogFilePath(logFile),
transform.WithAttemptID(att.ID()),
transform.WithPreconditions(dag.Preconditions),
transform.WithQueuedAt(stringutil.FormatTime(time.Now())),
transform.WithHierarchyRefs(
execution.NewDAGRunRef(dag.Name, dagRunID),
execution.DAGRunRef{},
),
}
dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{}, opts...)
// Write the status
if err := att.Open(ctx.Context); err != nil {
return fmt.Errorf("failed to open attempt: %w", err)
}
defer func() {
_ = att.Close(ctx.Context)
}()
if err := att.Write(ctx.Context, dagStatus); err != nil {
return fmt.Errorf("failed to save status: %w", err)
}
// Enqueue the DAG run
dagRun := execution.NewDAGRunRef(dag.Name, dagRunID)
if err := ctx.QueueStore.Enqueue(ctx.Context, dag.ProcGroup(), execution.QueuePriorityLow, dagRun); err != nil {
return fmt.Errorf("failed to enqueue: %w", err)
}
logger.Info(ctx, "Retry attempt created and enqueued", tag.AttemptID(att.ID()))
return nil
}

View File

@ -57,7 +57,7 @@ This command parses the DAG definition, resolves parameters, and initiates the D
}
// Command line flags for the start command
var startFlags = []commandLineFlag{paramsFlag, nameFlag, dagRunIDFlag, fromRunIDFlag, parentDAGRunFlag, rootDAGRunFlag, noQueueFlag, disableMaxActiveRuns, defaultWorkingDirFlag}
var startFlags = []commandLineFlag{paramsFlag, nameFlag, dagRunIDFlag, fromRunIDFlag, parentDAGRunFlag, rootDAGRunFlag, defaultWorkingDirFlag}
var fromRunIDFlag = commandLineFlag{
name: "from-run-id",
@ -81,8 +81,6 @@ func runStart(ctx *Context, args []string) error {
return fmt.Errorf("--from-run-id cannot be combined with --parent or --root")
}
disableMaxActiveRuns := ctx.Command.Flags().Changed("disable-max-active-runs")
var (
dag *core.DAG
params string
@ -157,14 +155,6 @@ func runStart(ctx *Context, args []string) error {
return handleSubDAGRun(ctx, dag, dagRunID, params, root, parent)
}
// Check if queue is disabled via config or flag
queueDisabled := !ctx.Config.Queues.Enabled
// check no-queue flag (overrides config)
if ctx.Command.Flags().Changed("no-queue") {
queueDisabled = true
}
// Check if the DAG run-id is unique
attempt, _ := ctx.DAGRunStore.FindAttempt(ctx, root)
if attempt != nil {
@ -172,15 +162,6 @@ func runStart(ctx *Context, args []string) error {
return fmt.Errorf("dag-run ID %s already exists for DAG %s", dagRunID, dag.Name)
}
// Count running DAG to check against maxActiveRuns setting (best effort).
liveCount, err := ctx.ProcStore.CountAliveByDAGName(ctx, dag.ProcGroup(), dag.Name)
if err != nil {
return fmt.Errorf("failed to access proc store: %w", err)
}
if !disableMaxActiveRuns && dag.MaxActiveRuns == 1 && liveCount > 0 {
return fmt.Errorf("DAG %s is already running, cannot start", dag.Name)
}
// Log root dag-run or reschedule action
if fromRunID != "" {
logger.Info(ctx, "Rescheduling dag-run",
@ -191,78 +172,36 @@ func runStart(ctx *Context, args []string) error {
logger.Info(ctx, "Executing root dag-run", slog.String("params", params))
}
// Check if this DAG should be distributed to workers
// If the DAG has a workerSelector and the queue is not disabled,
// enqueue it so the scheduler can dispatch it to a worker.
// The --no-queue flag acts as a circuit breaker to prevent infinite loops
// when the worker executes the dispatched task.
if !queueDisabled && len(dag.WorkerSelector) > 0 {
logger.Info(ctx, "DAG has workerSelector, enqueueing for distributed execution", slog.Any("worker-selector", dag.WorkerSelector))
dag.Location = "" // Queued dag-runs must not have a location
return enqueueDAGRun(ctx, dag, dagRunID)
}
err = tryExecuteDAG(ctx, dag, dagRunID, root, disableMaxActiveRuns)
if errors.Is(err, errMaxRunReached) && !queueDisabled && !disableMaxActiveRuns {
dag.Location = "" // Queued dag-runs must not have a location
// If the DAG has a queue configured and maxActiveRuns > 1, ensure the number
// of active runs in the queue does not exceed this limit.
// The scheduler only enforces maxActiveRuns at the global queue level.
queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
if err != nil {
return fmt.Errorf("failed to read queue: %w", err)
}
if dag.Queue != "" && dag.MaxActiveRuns > 1 && len(queuedRuns)+liveCount >= dag.MaxActiveRuns {
return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot start", dag.Name, dag.MaxActiveRuns)
}
// Enqueue the DAG-run for execution
return enqueueDAGRun(ctx, dag, dagRunID)
}
return err // return executed result
return tryExecuteDAG(ctx, dag, dagRunID, root)
}
var (
errMaxRunReached = errors.New("max run reached")
errProcAcquisitionFailed = errors.New("failed to acquire process handle")
)
// tryExecuteDAG tries to run the DAG within the max concurrent run config
func tryExecuteDAG(ctx *Context, dag *core.DAG, dagRunID string, root execution.DAGRunRef, disableMaxActiveRuns bool) error {
// tryExecuteDAG tries to run the DAG.
func tryExecuteDAG(ctx *Context, dag *core.DAG, dagRunID string, root execution.DAGRunRef) error {
if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
logger.Debug(ctx, "Failed to lock process group", tag.Error(err))
return errMaxRunReached
}
defer ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
if !disableMaxActiveRuns {
runningCount, err := ctx.ProcStore.CountAlive(ctx, dag.ProcGroup())
if err != nil {
logger.Debug(ctx, "Failed to count live processes", tag.Error(err))
return fmt.Errorf("failed to count live process for %s: %w", dag.ProcGroup(), errMaxRunReached)
}
// If the DAG has a queue configured and maxActiveRuns > 0, ensure the number
// of active runs in the queue does not exceed this limit.
if dag.MaxActiveRuns > 0 && runningCount >= dag.MaxActiveRuns {
// It's not possible to run right now.
return fmt.Errorf("max active run is reached (%d >= %d): %w", runningCount, dag.MaxActiveRuns, errMaxRunReached)
}
_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
return errProcAcquisitionFailed
}
// Acquire process handle
proc, err := ctx.ProcStore.Acquire(ctx, dag.ProcGroup(), execution.NewDAGRunRef(dag.Name, dagRunID))
if err != nil {
ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
}
defer func() {
_ = proc.Stop(ctx)
}()
ctx.Proc = proc
// Unlock the process group
// Unlock the process group immediately after acquiring the handle
// to allow other instances of the same DAG to start.
ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
return executeDAGRun(ctx, dag, execution.DAGRunRef{}, dagRunID, root)

View File

@ -10,7 +10,6 @@ import (
"github.com/dagu-org/dagu/internal/common/logger"
"github.com/dagu-org/dagu/internal/common/logger/tag"
"github.com/dagu-org/dagu/internal/service/coordinator"
"github.com/dagu-org/dagu/internal/service/resource"
"github.com/spf13/cobra"
)
@ -103,16 +102,9 @@ func runStartAll(ctx *Context, _ []string) error {
}
// Only start coordinator if not bound to localhost
var coordinator *coordinator.Service
enableCoordinator := ctx.Config.Coordinator.Host != "127.0.0.1" && ctx.Config.Coordinator.Host != "localhost" && ctx.Config.Coordinator.Host != "::1"
if enableCoordinator {
coordinator, err = newCoordinator(ctx, ctx.Config, ctx.ServiceRegistry)
if err != nil {
return fmt.Errorf("failed to initialize coordinator: %w", err)
}
} else {
logger.Info(ctx, "Coordinator disabled (bound to localhost), set --coordinator.host and --coordinator.advertise to enable distributed mode")
coordinator, err := newCoordinator(ctx, ctx.Config, ctx.ServiceRegistry)
if err != nil {
return fmt.Errorf("failed to initialize coordinator: %w", err)
}
// Create a new context with the signal context for services
@ -136,10 +128,7 @@ func runStartAll(ctx *Context, _ []string) error {
// WaitGroup to track all services
var wg sync.WaitGroup
serviceCount := 2 // scheduler + server
if enableCoordinator {
serviceCount = 3 // + coordinator
}
serviceCount := 3 // scheduler + server + coordinator
errCh := make(chan error, serviceCount)
// Start scheduler
@ -157,19 +146,16 @@ func runStartAll(ctx *Context, _ []string) error {
}
}()
// Start coordinator (if enabled)
if enableCoordinator {
wg.Add(1)
go func() {
defer wg.Done()
if err := coordinator.Start(serviceCtx); err != nil {
select {
case errCh <- fmt.Errorf("coordinator failed: %w", err):
default:
}
wg.Add(1)
go func() {
defer wg.Done()
if err := coordinator.Start(serviceCtx); err != nil {
select {
case errCh <- fmt.Errorf("coordinator failed: %w", err):
default:
}
}()
}
}
}()
// Start server
wg.Add(1)
@ -203,13 +189,11 @@ func runStartAll(ctx *Context, _ []string) error {
// Stop all services gracefully
logger.Info(ctx, "Stopping all services")
// Stop coordinator first to unregister from service registry (if it was started)
if enableCoordinator && coordinator != nil {
if err := coordinator.Stop(ctx); err != nil {
logger.Error(ctx, "Failed to stop coordinator",
tag.Error(err),
)
}
// Stop coordinator first to unregister from service registry
if err := coordinator.Stop(ctx); err != nil {
logger.Error(ctx, "Failed to stop coordinator",
tag.Error(err),
)
}
// Stop resource service

View File

@ -224,6 +224,8 @@ func TestRetrier_Next(t *testing.T) {
})
}
var _ RetryPolicy = (*mockRetryPolicy)(nil)
// mockRetryPolicy is a test helper that returns predefined intervals
type mockRetryPolicy struct {
intervals []time.Duration

View File

@ -266,6 +266,14 @@ func processStructFields(ctx context.Context, v reflect.Value, opts *EvalOptions
// nolint:exhaustive
switch field.Kind() {
case reflect.Ptr:
if field.IsNil() {
continue
}
if err := processPointerField(ctx, field, opts); err != nil {
return err
}
case reflect.String:
value := field.String()
@ -302,6 +310,137 @@ func processStructFields(ctx context.Context, v reflect.Value, opts *EvalOptions
}
field.Set(processed)
case reflect.Slice, reflect.Array:
if field.IsNil() {
continue
}
// Create a new slice with new backing array to avoid mutating original
newSlice := reflect.MakeSlice(field.Type(), field.Len(), field.Cap())
reflect.Copy(newSlice, field)
if err := processSliceWithOpts(ctx, newSlice, opts); err != nil {
return err
}
field.Set(newSlice)
}
}
return nil
}
func processPointerField(ctx context.Context, field reflect.Value, opts *EvalOptions) error {
elem := field.Elem()
if !elem.CanSet() {
return nil
}
// nolint:exhaustive
switch elem.Kind() {
case reflect.String:
value := elem.String()
value = expandVariables(ctx, value, opts)
if opts.Substitute {
var err error
value, err = substituteCommands(value)
if err != nil {
return err
}
}
if opts.ExpandEnv {
value = os.ExpandEnv(value)
}
// Create a new string and update the pointer to point to it
// This avoids mutating the original value
newStr := reflect.New(elem.Type())
newStr.Elem().SetString(value)
field.Set(newStr)
case reflect.Struct:
// Create a copy of the struct to avoid mutating the original
newStruct := reflect.New(elem.Type())
newStruct.Elem().Set(elem)
if err := processStructFields(ctx, newStruct.Elem(), opts); err != nil {
return err
}
field.Set(newStruct)
case reflect.Map:
if elem.IsNil() {
return nil
}
processed, err := processMapWithOpts(ctx, elem, opts)
if err != nil {
return err
}
// Create a new pointer to the processed map
newMap := reflect.New(elem.Type())
newMap.Elem().Set(processed)
field.Set(newMap)
case reflect.Slice, reflect.Array:
if elem.IsNil() {
return nil
}
// Create a new slice with new backing array to avoid mutating original
newSlice := reflect.MakeSlice(elem.Type(), elem.Len(), elem.Cap())
reflect.Copy(newSlice, elem)
if err := processSliceWithOpts(ctx, newSlice, opts); err != nil {
return err
}
// Create new pointer to the new slice
newPtr := reflect.New(elem.Type())
newPtr.Elem().Set(newSlice)
field.Set(newPtr)
}
return nil
}
func processSliceWithOpts(ctx context.Context, v reflect.Value, opts *EvalOptions) error {
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if !elem.CanSet() {
continue
}
// nolint:exhaustive
switch elem.Kind() {
case reflect.String:
value := elem.String()
value = expandVariables(ctx, value, opts)
if opts.Substitute {
var err error
value, err = substituteCommands(value)
if err != nil {
return err
}
}
if opts.ExpandEnv {
value = os.ExpandEnv(value)
}
elem.SetString(value)
case reflect.Struct:
if err := processStructFields(ctx, elem, opts); err != nil {
return err
}
case reflect.Map:
if elem.IsNil() {
continue
}
processed, err := processMapWithOpts(ctx, elem, opts)
if err != nil {
return err
}
elem.Set(processed)
case reflect.Ptr:
if elem.IsNil() {
continue
}
if err := processPointerField(ctx, elem, opts); err != nil {
return err
}
}
}
return nil
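The repeated copy-before-mutate steps above exist so that evaluation never writes through to the caller's value. A sketch of the guarantee (Config is a hypothetical type; EvalStringFields is assumed to return a processed copy of its input, as the tests below verify):

```go
func exampleCopyOnWrite(ctx context.Context) {
	type Config struct {
		Token *string
		Tags  []string
	}
	token := "$API_TOKEN"
	orig := Config{Token: &token, Tags: []string{"$ENV_NAME"}}

	processed, _ := EvalStringFields(ctx, orig)

	// orig is untouched: *orig.Token is still "$API_TOKEN" and
	// orig.Tags[0] is still "$ENV_NAME"; processed holds the expanded copies.
	_ = processed
}
```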

View File

@ -3,7 +3,6 @@ package cmdutil
import (
"context"
"os"
"reflect"
"strconv"
"testing"
@ -85,13 +84,12 @@ func TestEvalStringFields(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
got, err := EvalStringFields(ctx, tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("SubstituteStringFields() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
return
}
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, tt.want)
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
@ -107,12 +105,10 @@ func TestEvalStringFields_AnonymousStruct(t *testing.T) {
require.Equal(t, "hello", obj.Field)
}
func TestSubstituteStringFields_NonStruct(t *testing.T) {
func TestEvalStringFields_NonStruct(t *testing.T) {
ctx := context.Background()
_, err := EvalStringFields(ctx, "not a struct")
if err == nil {
t.Error("SubstituteStringFields() should return error for non-struct input")
}
assert.Error(t, err)
}
func TestEvalStringFields_NestedStructs(t *testing.T) {
@ -160,13 +156,8 @@ func TestEvalStringFields_NestedStructs(t *testing.T) {
ctx := context.Background()
got, err := EvalStringFields(ctx, input)
if err != nil {
t.Fatalf("SubstituteStringFields() error = %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, want)
}
require.NoError(t, err)
assert.Equal(t, want, got)
}
func TestEvalStringFields_EmptyStruct(t *testing.T) {
@ -175,13 +166,541 @@ func TestEvalStringFields_EmptyStruct(t *testing.T) {
input := Empty{}
ctx := context.Background()
got, err := EvalStringFields(ctx, input)
if err != nil {
t.Fatalf("SubstituteStringFields() error = %v", err)
require.NoError(t, err)
assert.Equal(t, input, got)
}
func TestEvalStringFields_PointerFields(t *testing.T) {
_ = os.Setenv("PTR_VAR", "pointer_value")
defer func() {
_ = os.Unsetenv("PTR_VAR")
}()
type PointerNested struct {
Value string
}
if !reflect.DeepEqual(got, input) {
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, input)
type PointerStruct struct {
Token *string
Nested *PointerNested
Items []*PointerNested
}
ctx := context.Background()
token := "$PTR_VAR"
input := PointerStruct{
Token: &token,
Nested: &PointerNested{Value: "${PTR_VAR}"},
Items: []*PointerNested{{
Value: "$PTR_VAR",
}},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.NotNil(t, result.Token)
assert.Equal(t, "pointer_value", *result.Token)
require.NotNil(t, result.Nested)
assert.Equal(t, "pointer_value", result.Nested.Value)
require.Len(t, result.Items, 1)
assert.Equal(t, "pointer_value", result.Items[0].Value)
}
func TestEvalStringFields_NilPointerFields(t *testing.T) {
t.Setenv("NIL_TEST_VAR", "value")
type Nested struct {
Value string
}
type StructWithNilPointers struct {
NilString *string
NilStruct *Nested
NilMap *map[string]string
NilSlice *[]string
// Non-nil field to verify processing continues
Regular string
}
ctx := context.Background()
input := StructWithNilPointers{
NilString: nil,
NilStruct: nil,
NilMap: nil,
NilSlice: nil,
Regular: "$NIL_TEST_VAR",
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Nil(t, result.NilString)
assert.Nil(t, result.NilStruct)
assert.Nil(t, result.NilMap)
assert.Nil(t, result.NilSlice)
assert.Equal(t, "value", result.Regular)
}
func TestEvalStringFields_PointerToMap(t *testing.T) {
t.Setenv("MAP_PTR_VAR", "map_ptr_value")
type StructWithPtrMap struct {
Config *map[string]string
}
ctx := context.Background()
t.Run("PointerToMapWithEnvVars", func(t *testing.T) {
mapVal := map[string]string{
"key1": "$MAP_PTR_VAR",
"key2": "${MAP_PTR_VAR}",
}
input := StructWithPtrMap{Config: &mapVal}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.NotNil(t, result.Config)
assert.Equal(t, "map_ptr_value", (*result.Config)["key1"])
assert.Equal(t, "map_ptr_value", (*result.Config)["key2"])
})
t.Run("PointerToNilMap", func(t *testing.T) {
input := StructWithPtrMap{Config: nil}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Nil(t, result.Config)
})
}
func TestEvalStringFields_PointerToSlice(t *testing.T) {
t.Setenv("SLICE_PTR_VAR", "slice_ptr_value")
type Nested struct {
Value string
}
type StructWithPtrSlice struct {
Items *[]string
Structs *[]*Nested
}
ctx := context.Background()
t.Run("PointerToSliceOfStrings", func(t *testing.T) {
items := []string{"$SLICE_PTR_VAR", "${SLICE_PTR_VAR}"}
input := StructWithPtrSlice{Items: &items}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.NotNil(t, result.Items)
require.Len(t, *result.Items, 2)
assert.Equal(t, "slice_ptr_value", (*result.Items)[0])
assert.Equal(t, "slice_ptr_value", (*result.Items)[1])
})
t.Run("PointerToSliceOfStructPointers", func(t *testing.T) {
structs := []*Nested{
{Value: "$SLICE_PTR_VAR"},
{Value: "${SLICE_PTR_VAR}"},
}
input := StructWithPtrSlice{Structs: &structs}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.NotNil(t, result.Structs)
require.Len(t, *result.Structs, 2)
assert.Equal(t, "slice_ptr_value", (*result.Structs)[0].Value)
assert.Equal(t, "slice_ptr_value", (*result.Structs)[1].Value)
})
t.Run("PointerToNilSlice", func(t *testing.T) {
input := StructWithPtrSlice{Items: nil}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Nil(t, result.Items)
})
}
func TestEvalStringFields_SliceOfStrings(t *testing.T) {
t.Setenv("SLICE_STR_VAR", "slice_str_value")
type StructWithStringSlice struct {
Tags []string
Labels []string
}
ctx := context.Background()
t.Run("SliceWithEnvVars", func(t *testing.T) {
input := StructWithStringSlice{
Tags: []string{"$SLICE_STR_VAR", "${SLICE_STR_VAR}", "plain"},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.Tags, 3)
assert.Equal(t, "slice_str_value", result.Tags[0])
assert.Equal(t, "slice_str_value", result.Tags[1])
assert.Equal(t, "plain", result.Tags[2])
})
t.Run("EmptySlice", func(t *testing.T) {
input := StructWithStringSlice{
Tags: []string{},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Empty(t, result.Tags)
})
t.Run("NilSlice", func(t *testing.T) {
input := StructWithStringSlice{
Tags: nil,
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Nil(t, result.Tags)
})
t.Run("SliceWithCommandSubstitution", func(t *testing.T) {
input := StructWithStringSlice{
Tags: []string{"`echo hello`", "`echo world`"},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.Tags, 2)
assert.Equal(t, "hello", result.Tags[0])
assert.Equal(t, "world", result.Tags[1])
})
}
func TestEvalStringFields_SliceOfStructs(t *testing.T) {
t.Setenv("STRUCT_SLICE_VAR", "struct_slice_value")
type Item struct {
Name string
Value string
}
type StructWithStructSlice struct {
Items []Item
}
ctx := context.Background()
t.Run("SliceOfStructsWithEnvVars", func(t *testing.T) {
input := StructWithStructSlice{
Items: []Item{
{Name: "$STRUCT_SLICE_VAR", Value: "plain"},
{Name: "plain", Value: "${STRUCT_SLICE_VAR}"},
},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.Items, 2)
assert.Equal(t, "struct_slice_value", result.Items[0].Name)
assert.Equal(t, "plain", result.Items[0].Value)
assert.Equal(t, "plain", result.Items[1].Name)
assert.Equal(t, "struct_slice_value", result.Items[1].Value)
})
t.Run("EmptySliceOfStructs", func(t *testing.T) {
input := StructWithStructSlice{
Items: []Item{},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
assert.Empty(t, result.Items)
})
}
func TestEvalStringFields_SliceOfMaps(t *testing.T) {
t.Setenv("MAP_SLICE_VAR", "map_slice_value")
type StructWithMapSlice struct {
Configs []map[string]string
}
ctx := context.Background()
t.Run("SliceOfMapsWithEnvVars", func(t *testing.T) {
input := StructWithMapSlice{
Configs: []map[string]string{
{"key": "$MAP_SLICE_VAR"},
{"key": "${MAP_SLICE_VAR}"},
},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.Configs, 2)
assert.Equal(t, "map_slice_value", result.Configs[0]["key"])
assert.Equal(t, "map_slice_value", result.Configs[1]["key"])
})
t.Run("SliceWithNilMapElement", func(t *testing.T) {
input := StructWithMapSlice{
Configs: []map[string]string{
{"key": "$MAP_SLICE_VAR"},
nil,
{"key": "plain"},
},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.Configs, 3)
assert.Equal(t, "map_slice_value", result.Configs[0]["key"])
assert.Nil(t, result.Configs[1])
assert.Equal(t, "plain", result.Configs[2]["key"])
})
}
func TestEvalStringFields_SliceWithNilPointers(t *testing.T) {
t.Setenv("NIL_PTR_SLICE_VAR", "nil_ptr_slice_value")
type Nested struct {
Value string
}
type StructWithPointerSlice struct {
StringPtrs []*string
StructPtrs []*Nested
}
ctx := context.Background()
t.Run("SliceWithMixedNilAndNonNilStringPointers", func(t *testing.T) {
val1 := "$NIL_PTR_SLICE_VAR"
val2 := "${NIL_PTR_SLICE_VAR}"
input := StructWithPointerSlice{
StringPtrs: []*string{&val1, nil, &val2},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.StringPtrs, 3)
require.NotNil(t, result.StringPtrs[0])
assert.Equal(t, "nil_ptr_slice_value", *result.StringPtrs[0])
assert.Nil(t, result.StringPtrs[1])
require.NotNil(t, result.StringPtrs[2])
assert.Equal(t, "nil_ptr_slice_value", *result.StringPtrs[2])
})
t.Run("SliceWithMixedNilAndNonNilStructPointers", func(t *testing.T) {
input := StructWithPointerSlice{
StructPtrs: []*Nested{
{Value: "$NIL_PTR_SLICE_VAR"},
nil,
{Value: "${NIL_PTR_SLICE_VAR}"},
},
}
result, err := EvalStringFields(ctx, input)
require.NoError(t, err)
require.Len(t, result.StructPtrs, 3)
require.NotNil(t, result.StructPtrs[0])
assert.Equal(t, "nil_ptr_slice_value", result.StructPtrs[0].Value)
assert.Nil(t, result.StructPtrs[1])
require.NotNil(t, result.StructPtrs[2])
assert.Equal(t, "nil_ptr_slice_value", result.StructPtrs[2].Value)
})
}
func TestEvalStringFields_PointerMutation(t *testing.T) {
t.Setenv("MUT_VAR", "expanded")
t.Run("PointerToString", func(t *testing.T) {
type S struct {
Token *string
}
original := "$MUT_VAR"
input := S{Token: &original}
result, err := EvalStringFields(context.Background(), input)
require.NoError(t, err)
// Check result is correct
assert.Equal(t, "expanded", *result.Token)
// Check original was NOT mutated
assert.Equal(t, "$MUT_VAR", original, "BUG: original variable was mutated")
assert.Equal(t, "$MUT_VAR", *input.Token, "BUG: input struct's pointer target was mutated")
})
t.Run("SliceOfStringPointers", func(t *testing.T) {
type S struct {
Items []*string
}
val1 := "$MUT_VAR"
val2 := "${MUT_VAR}"
input := S{Items: []*string{&val1, &val2}}
result, err := EvalStringFields(context.Background(), input)
require.NoError(t, err)
// Check result is correct
require.Len(t, result.Items, 2)
assert.Equal(t, "expanded", *result.Items[0])
assert.Equal(t, "expanded", *result.Items[1])
// Check originals were NOT mutated
assert.Equal(t, "$MUT_VAR", val1, "BUG: val1 was mutated")
assert.Equal(t, "${MUT_VAR}", val2, "BUG: val2 was mutated")
assert.Equal(t, "$MUT_VAR", *input.Items[0], "BUG: input.Items[0] target was mutated")
assert.Equal(t, "${MUT_VAR}", *input.Items[1], "BUG: input.Items[1] target was mutated")
})
t.Run("PointerToStruct", func(t *testing.T) {
type Nested struct {
Value string
}
type S struct {
Nested *Nested
}
input := S{Nested: &Nested{Value: "$MUT_VAR"}}
originalValue := input.Nested.Value
result, err := EvalStringFields(context.Background(), input)
require.NoError(t, err)
// Check result is correct
assert.Equal(t, "expanded", result.Nested.Value)
// Check original was NOT mutated
assert.Equal(t, "$MUT_VAR", originalValue, "BUG: original nested value was mutated")
assert.Equal(t, "$MUT_VAR", input.Nested.Value, "BUG: input.Nested.Value was mutated")
})
t.Run("PointerToSlice", func(t *testing.T) {
type S struct {
Items *[]string
}
items := []string{"$MUT_VAR", "${MUT_VAR}"}
input := S{Items: &items}
result, err := EvalStringFields(context.Background(), input)
require.NoError(t, err)
// Check result is correct
require.NotNil(t, result.Items)
assert.Equal(t, "expanded", (*result.Items)[0])
assert.Equal(t, "expanded", (*result.Items)[1])
// Check original was NOT mutated
assert.Equal(t, "$MUT_VAR", items[0], "BUG: items[0] was mutated")
assert.Equal(t, "${MUT_VAR}", items[1], "BUG: items[1] was mutated")
})
}
func TestEvalStringFields_PointerFieldErrors(t *testing.T) {
type Nested struct {
Command string
}
type StructWithPointerErrors struct {
StringPtr *string
StructPtr *Nested
MapPtr *map[string]string
}
ctx := context.Background()
t.Run("PointerToStringWithInvalidCommand", func(t *testing.T) {
invalidCmd := "`invalid_command_xyz123`"
input := StructWithPointerErrors{StringPtr: &invalidCmd}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
t.Run("PointerToStructWithInvalidNestedCommand", func(t *testing.T) {
input := StructWithPointerErrors{
StructPtr: &Nested{Command: "`invalid_command_xyz123`"},
}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
t.Run("PointerToMapWithInvalidCommand", func(t *testing.T) {
mapVal := map[string]string{
"key": "`invalid_command_xyz123`",
}
input := StructWithPointerErrors{MapPtr: &mapVal}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
}
func TestEvalStringFields_SliceFieldErrors(t *testing.T) {
type Nested struct {
Command string
}
type StructWithSliceErrors struct {
Strings []string
Structs []Nested
StringPtrs []*string
Maps []map[string]string
}
ctx := context.Background()
t.Run("SliceStringWithInvalidCommand", func(t *testing.T) {
input := StructWithSliceErrors{
Strings: []string{"valid", "`invalid_command_xyz123`"},
}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
t.Run("SliceStructWithInvalidNestedCommand", func(t *testing.T) {
input := StructWithSliceErrors{
Structs: []Nested{
{Command: "valid"},
{Command: "`invalid_command_xyz123`"},
},
}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
t.Run("SlicePointerWithInvalidCommand", func(t *testing.T) {
valid := "valid"
invalid := "`invalid_command_xyz123`"
input := StructWithSliceErrors{
StringPtrs: []*string{&valid, &invalid},
}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
t.Run("SliceMapWithInvalidCommand", func(t *testing.T) {
input := StructWithSliceErrors{
Maps: []map[string]string{
{"key": "valid"},
{"key": "`invalid_command_xyz123`"},
},
}
_, err := EvalStringFields(ctx, input)
assert.Error(t, err)
})
}
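Taken together, these tests pin down the EvalStringFields contract: $VAR and ${VAR} references are expanded, backtick command substitution runs (with failures surfacing as errors), nil pointers and nil elements are skipped, and the caller gets back a deep copy rather than a mutated input. A minimal in-package sketch of that contract; the config type and variable names are made up, and the fmt import is assumed:
func exampleEvalStringFields(ctx context.Context) error {
	type notifyConfig struct {
		Token   *string  // pointers are followed; nil pointers are skipped
		Targets []string // every slice element is evaluated
	}
	token := "$API_TOKEN"
	in := notifyConfig{Token: &token, Targets: []string{"${HOST}", "`hostname`"}}
	out, err := EvalStringFields(ctx, in)
	if err != nil {
		return err // e.g. a backtick command that could not be run
	}
	// out is a fully expanded copy; in and token still hold the raw templates.
	fmt.Println(*out.Token, out.Targets)
	return nil
}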
func TestReplaceVars(t *testing.T) {

View File

@ -0,0 +1,43 @@
package cmdutil
import (
"strings"
)
// ShellQuote escapes a string for use in a shell command.
func ShellQuote(s string) string {
if s == "" {
return "''"
}
// Use a conservative set of safe characters:
// Alphanumeric, hyphen, underscore, dot, and slash.
// We only consider ASCII alphanumeric as safe to avoid locale-dependent behavior.
safe := true
for i := 0; i < len(s); i++ {
b := s[i]
if (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '/' {
continue
}
safe = false
break
}
if safe {
return s
}
// Wrap in single quotes and escape any internal single quotes.
// This is the most robust way to escape for POSIX-compliant shells.
// 'user's file' -> 'user'\''s file'
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}
// ShellQuoteArgs escapes a slice of strings for use in a shell command.
func ShellQuoteArgs(args []string) string {
quoted := make([]string, len(args))
for i, arg := range args {
quoted[i] = ShellQuote(arg)
}
return strings.Join(quoted, " ")
}
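The intended use is to pass untrusted values through ShellQuote before handing a command line to sh -c, so each value parses as a single literal word. A hedged in-package sketch (os/exec and fmt imports assumed; the hostile value is invented):
func exampleShellQuote() {
	name := "report; rm -rf /" // attacker-controlled input
	out, _ := exec.Command("sh", "-c", "printf %s "+ShellQuote(name)).Output()
	fmt.Println(string(out)) // prints: report; rm -rf /  (nothing was executed)

	// ShellQuoteArgs quotes each argument and joins them with spaces.
	fmt.Println(ShellQuoteArgs([]string{"echo", "It's done", "$HOME"}))
	// prints: echo 'It'\''s done' '$HOME'
}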

View File

@ -0,0 +1,265 @@
package cmdutil
import (
"os/exec"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShellQuote(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Empty string",
input: "",
expected: "''",
},
{
name: "Safe alphanumeric",
input: "abcXYZ123",
expected: "abcXYZ123",
},
{
name: "Safe special characters",
input: "-_./",
expected: "-_./",
},
{
name: "String with space",
input: "hello world",
expected: "'hello world'",
},
{
name: "String with single quote",
input: "user's file",
expected: "'user'\\''s file'",
},
{
name: "String with multiple single quotes",
input: "a'b'c",
expected: "'a'\\''b'\\''c'",
},
{
name: "String with double quote",
input: `"quoted"`,
expected: `'"quoted"'`,
},
{
name: "String with dollar sign",
input: "$VAR",
expected: "'$VAR'",
},
{
name: "String with asterisk",
input: "*.txt",
expected: "'*.txt'",
},
{
name: "String with backtick",
input: "`date`",
expected: "'`date`'",
},
{
name: "String with semicolon",
input: "ls; rm -rf /",
expected: "'ls; rm -rf /'",
},
{
name: "String with ampersand",
input: "run &",
expected: "'run &'",
},
{
name: "String with pipe",
input: "a | b",
expected: "'a | b'",
},
{
name: "String with parentheses",
input: "(subshell)",
expected: "'(subshell)'",
},
{
name: "String with brackets",
input: "[abc]",
expected: "'[abc]'",
},
{
name: "String with braces",
input: "{1..10}",
expected: "'{1..10}'",
},
{
name: "String with redirection",
input: "> output.txt",
expected: "'> output.txt'",
},
{
name: "String with backslash",
input: "path\\to\\file",
expected: "'path\\to\\file'",
},
{
name: "String with newline",
input: "line1\nline2",
expected: "'line1\nline2'",
},
{
name: "String with tab",
input: "field1\tfield2",
expected: "'field1\tfield2'",
},
{
name: "Unicode string",
input: "Hello 世界",
expected: "'Hello 世界'",
},
{
name: "Mixed single and double quotes",
input: "It's a \"test\"",
expected: "'It'\\''s a \"test\"'",
},
{
name: "Only single quotes",
input: "'''",
expected: "''\\'''\\'''\\'''",
},
{
name: "Backslashes and quotes",
input: `\"'\`,
expected: `'\"'\''\'`,
},
{
name: "Non-printable characters",
input: "\x01\x02\x03",
expected: "'\x01\x02\x03'",
},
{
name: "Terminal escape sequence",
input: "\x1b[31mRed\x1b[0m",
expected: "'\x1b[31mRed\x1b[0m'",
},
{
name: "Ultra nasty mixed string",
input: `'"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~",
expected: `''\''"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~'",
},
{
name: "Leading/trailing spaces",
input: " spaced ",
expected: "' spaced '",
},
{
name: "Command injection attempt",
input: "; rm -rf / ; #",
expected: "'; rm -rf / ; #'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := ShellQuote(tt.input)
assert.Equal(t, tt.expected, actual, "Input: %s", tt.input)
})
}
}
func TestShellQuote_ShellRoundTrip(t *testing.T) {
if _, err := exec.LookPath("sh"); err != nil {
t.Skip("sh not found in PATH")
}
inputs := []string{
"simple",
"with space",
"with 'single' quote",
"with \"double\" quote",
"with $dollar",
"with `backtick`",
"with \\backslash",
"with \nnewline",
"with \ttab",
"with 世界 (unicode)",
"with mixed \"' $`\\ \n\t chars",
"",
"'-'",
"\"-\"",
`'"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~", // Ultra nasty
" leading and trailing spaces ",
"!!!@@@###$$$%%%^^^&&&***((()))_++==--",
}
for _, input := range inputs {
t.Run(input, func(t *testing.T) {
quoted := ShellQuote(input)
// Run sh -c 'printf %s <quoted>' and capture output
// We use printf because echo might interpret sequences or add newlines
cmd := exec.Command("sh", "-c", "printf %s "+quoted)
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("sh failed for input %q: %v\nOutput: %s", input, err, string(output))
}
assert.Equal(t, input, string(output), "Round-trip failed for input %q", input)
})
}
}
func TestShellQuoteArgs(t *testing.T) {
tests := []struct {
name string
args []string
expected string
}{
{
name: "No args",
args: []string{},
expected: "",
},
{
name: "Single safe arg",
args: []string{"ls"},
expected: "ls",
},
{
name: "Single unsafe arg",
args: []string{"ls -l"},
expected: "'ls -l'",
},
{
name: "Multiple args",
args: []string{"ls", "-l", "my file.txt"},
expected: "ls -l 'my file.txt'",
},
{
name: "Complex args",
args: []string{"echo", "It's a beautiful day", "$HOME"},
expected: "echo 'It'\\''s a beautiful day' '$HOME'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := ShellQuoteArgs(tt.args)
assert.Equal(t, tt.expected, actual)
})
}
}
func TestShellQuote_RoundTrip(t *testing.T) {
// Exhaustive list of characters to test
chars := ""
for i := 0; i < 256; i++ {
chars += string(rune(i))
}
quoted := ShellQuote(chars)
// Without a shell parser to verify against, we only sanity-check that the
// result is non-empty and never shorter than the input.
assert.NotEmpty(t, quoted)
if len(chars) > 0 {
assert.True(t, len(quoted) >= len(chars))
}
}

View File

@ -201,6 +201,7 @@ type PathsConfig struct {
ProcDir string
ServiceRegistryDir string // Directory for service registry files
UsersDir string // Directory for user data (builtin auth)
APIKeysDir string // Directory for API key data (builtin auth)
ConfigFileUsed string // Path to the configuration file used to load settings
}

View File

@ -337,4 +337,5 @@ func TestConfig_Validate(t *testing.T) {
err := cfg.Validate()
require.NoError(t, err)
})
}

View File

@ -217,6 +217,7 @@ type PathsDef struct {
ProcDir string `mapstructure:"procDir"`
ServiceRegistryDir string `mapstructure:"serviceRegistryDir"`
UsersDir string `mapstructure:"usersDir"`
APIKeysDir string `mapstructure:"apiKeysDir"`
}
// UIDef holds the user interface configuration settings.

View File

@ -1,6 +1,7 @@
package config
import (
"context"
"fmt"
"log"
"os"
@ -11,6 +12,7 @@ import (
"time"
"github.com/adrg/xdg"
"github.com/dagu-org/dagu/internal/common/cmdutil"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/spf13/viper"
)
@ -172,6 +174,12 @@ func (l *ConfigLoader) Load() (*Config, error) {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
expandedDef, err := cmdutil.EvalStringFields(context.Background(), def, cmdutil.WithoutSubstitute())
if err != nil {
return nil, fmt.Errorf("failed to expand config variables: %w", err)
}
def = expandedDef
// Build the final Config from the definition (including legacy fields and validations).
cfg, err := l.buildConfig(def)
if err != nil {
@ -290,6 +298,7 @@ func (l *ConfigLoader) loadPathsConfig(cfg *Config, def Definition) {
cfg.Paths.ProcDir = fileutil.ResolvePathOrBlank(def.Paths.ProcDir)
cfg.Paths.ServiceRegistryDir = fileutil.ResolvePathOrBlank(def.Paths.ServiceRegistryDir)
cfg.Paths.UsersDir = fileutil.ResolvePathOrBlank(def.Paths.UsersDir)
cfg.Paths.APIKeysDir = fileutil.ResolvePathOrBlank(def.Paths.APIKeysDir)
}
}
@ -392,6 +401,13 @@ func (l *ConfigLoader) loadServerConfig(cfg *Config, def Definition) {
l.warnings = append(l.warnings, fmt.Sprintf("Auth mode auto-detected as 'oidc' based on OIDC configuration (issuer: %s)", oidc.Issuer))
}
}
// Warn if basic auth is configured with builtin auth mode (it will be ignored)
if cfg.Server.Auth.Mode == AuthModeBuiltin {
if cfg.Server.Auth.Basic.Username != "" || cfg.Server.Auth.Basic.Password != "" {
l.warnings = append(l.warnings, "Basic auth configuration is ignored when auth mode is 'builtin'; use builtin auth's admin credentials instead")
}
}
}
// Set default token TTL if not specified
@ -613,6 +629,9 @@ func (l *ConfigLoader) finalizePaths(cfg *Config) {
if cfg.Paths.UsersDir == "" {
cfg.Paths.UsersDir = filepath.Join(cfg.Paths.DataDir, "users")
}
if cfg.Paths.APIKeysDir == "" {
cfg.Paths.APIKeysDir = filepath.Join(cfg.Paths.DataDir, "apikeys")
}
if cfg.Paths.Executable == "" {
if executable, err := os.Executable(); err == nil {

View File

@ -200,7 +200,8 @@ func TestLoad_Env(t *testing.T) {
ProcDir: filepath.Join(testPaths, "proc"),
QueueDir: filepath.Join(testPaths, "queue"),
ServiceRegistryDir: filepath.Join(testPaths, "service-registry"),
UsersDir: filepath.Join(testPaths, "data", "users"), // Derived from DataDir
APIKeysDir: filepath.Join(testPaths, "data", "apikeys"), // Derived from DataDir
},
UI: UI{
LogEncodingCharset: "iso-8859-1",
@ -432,6 +433,7 @@ scheduler:
QueueDir: "/var/dagu/data/queue",
ServiceRegistryDir: "/var/dagu/data/service-registry",
UsersDir: "/var/dagu/data/users",
APIKeysDir: "/var/dagu/data/apikeys",
},
UI: UI{
LogEncodingCharset: "iso-8859-1",
@ -606,6 +608,26 @@ scheduler:
assert.Contains(t, cfg.Warnings[1], "Invalid scheduler.lockRetryInterval")
assert.Contains(t, cfg.Warnings[2], "Invalid scheduler.zombieDetectionInterval")
})
t.Run("BuiltinAuthWithBasicAuthWarning", func(t *testing.T) {
t.Parallel()
cfg := loadWithEnv(t, `
auth:
mode: builtin
builtin:
admin:
username: admin
token:
secret: test-secret
basic:
username: basicuser
password: basicpass
paths:
usersDir: /tmp/users
`, nil)
require.Len(t, cfg.Warnings, 1)
assert.Contains(t, cfg.Warnings[0], "Basic auth configuration is ignored when auth mode is 'builtin'")
})
}
func TestLoad_LegacyEnv(t *testing.T) {
@ -731,6 +753,19 @@ func loadWithEnv(t *testing.T, yaml string, env map[string]string) *Config {
return loadFromYAML(t, yaml)
}
func unsetEnv(t *testing.T, key string) {
t.Helper()
original, existed := os.LookupEnv(key)
os.Unsetenv(key)
t.Cleanup(func() {
if existed {
os.Setenv(key, original)
return
}
os.Unsetenv(key)
})
}
// loadFromYAML loads config from YAML string
func loadFromYAML(t *testing.T, yaml string) *Config {
t.Helper()
@ -963,3 +998,19 @@ auth:
assert.Equal(t, 24*time.Hour, cfg.Server.Auth.Builtin.Token.TTL)
})
}
func TestLoad_AuthTokenEnvExpansion(t *testing.T) {
t.Parallel()
unsetEnv(t, "DAGU_AUTH_TOKEN")
unsetEnv(t, "DAGU_AUTHTOKEN")
cfg := loadWithEnv(t, `
auth:
token:
value: "${AUTH_TOKEN}"
`, map[string]string{
"AUTH_TOKEN": "env-token",
})
assert.Equal(t, "env-token", cfg.Server.Auth.Token.Value)
}

View File

@ -1,33 +0,0 @@
package config
import "sync"
// globalViperMu serializes all access to the shared viper instance used across the application.
// Uses RWMutex to allow concurrent reads while serializing writes.
var globalViperMu sync.RWMutex
// lockViper acquires the global viper write lock (private).
func lockViper() {
globalViperMu.Lock()
}
// unlockViper releases the global viper write lock (private).
func unlockViper() {
globalViperMu.Unlock()
}
// WithViperLock runs the provided function while holding the global viper write lock.
// Use this for any operations that modify viper state (Set, BindPFlag, ReadInConfig, etc.).
func WithViperLock(fn func()) {
lockViper()
defer unlockViper()
fn()
}
// WithViperRLock runs the provided function while holding the global viper read lock.
// Use this for read-only operations (Get, IsSet, etc.) to allow concurrent access.
func WithViperRLock(fn func()) {
globalViperMu.RLock()
defer globalViperMu.RUnlock()
fn()
}

View File

@ -221,6 +221,8 @@ func TestRegistry_Providers(t *testing.T) {
assert.Contains(t, providers, "custom")
}
var _ Resolver = (*mockResolver)(nil)
// mockResolver is a test double for the Resolver interface
type mockResolver struct {
mockName string

View File

@ -68,6 +68,38 @@ func RemoveQuotes(s string) string {
return s
}
// ScreamingSnakeToCamel converts a SCREAMING_SNAKE_CASE string to camelCase.
// Example: TOTAL_COUNT -> totalCount, MY_VAR -> myVar, FOO -> foo
func ScreamingSnakeToCamel(s string) string {
parts := strings.Split(s, "_")
if len(parts) == 0 {
return ""
}
var result strings.Builder
isFirst := true
for _, part := range parts {
if part == "" {
continue
}
lower := strings.ToLower(part)
if isFirst {
result.WriteString(lower)
isFirst = false
} else {
// Capitalize first letter of subsequent parts
if len(lower) > 0 {
result.WriteString(strings.ToUpper(lower[:1]))
if len(lower) > 1 {
result.WriteString(lower[1:])
}
}
}
}
return result.String()
}
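For a quick feel of the mapping (mirroring the test table later in this diff), a small in-package sketch with an assumed fmt import:
func exampleScreamingSnakeToCamel() {
	for _, k := range []string{"TOTAL_COUNT", "FOO", "A_B_C", "DOUBLE__UNDERSCORE"} {
		fmt.Printf("%s -> %s\n", k, ScreamingSnakeToCamel(k))
	}
	// TOTAL_COUNT -> totalCount, FOO -> foo, A_B_C -> aBC,
	// DOUBLE__UNDERSCORE -> doubleUnderscore (empty parts are dropped)
}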
// KebabToCamel converts a kebab-case string to camelCase.
func KebabToCamel(s string) string {
parts := strings.Split(s, "-")

View File

@ -141,6 +141,33 @@ func TestKebabToCamel(t *testing.T) {
})
}
func TestScreamingSnakeToCamel(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"TOTAL_COUNT", "totalCount"},
{"MY_VAR", "myVar"},
{"FOO", "foo"},
{"FOO_BAR_BAZ", "fooBarBaz"},
{"", ""},
{"_LEADING", "leading"},
{"TRAILING_", "trailing"},
{"DOUBLE__UNDERSCORE", "doubleUnderscore"},
{"already_lower", "alreadyLower"},
{"MiXeD_CaSe", "mixedCase"},
{"A", "a"},
{"A_B_C", "aBC"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := stringutil.ScreamingSnakeToCamel(tt.input)
require.Equal(t, tt.expected, result)
})
}
}
func TestIsMultiLine(t *testing.T) {
tests := []struct {
name string

View File

@ -16,6 +16,8 @@ import (
"github.com/dagu-org/dagu/internal/core/spec"
)
var _ execution.DAGStore = (*mockDAGStore)(nil)
// Mock implementations
type mockDAGStore struct {
mock.Mock
@ -151,9 +153,12 @@ func (m *mockDAGRunStore) FindSubAttempt(ctx context.Context, dagRun execution.D
return args.Get(0).(execution.DAGRunAttempt), args.Error(1)
}
func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...execution.RemoveOldDAGRunsOption) ([]string, error) {
args := m.Called(ctx, name, retentionDays, opts)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]string), args.Error(1)
}
func (m *mockDAGRunStore) RenameDAGRuns(ctx context.Context, oldName, newName string) error {
@ -218,6 +223,8 @@ func (m *mockQueueStore) All(ctx context.Context) ([]execution.QueuedItemData, e
return args.Get(0).([]execution.QueuedItemData), args.Error(1)
}
var _ execution.ServiceRegistry = (*mockServiceRegistry)(nil)
type mockServiceRegistry struct {
mock.Mock
}

View File

@ -0,0 +1,83 @@
package core
// ExecutorCapabilities defines what an executor can do.
type ExecutorCapabilities struct {
// Command indicates whether the executor supports the command field.
Command bool
// MultipleCommands indicates whether the executor supports multiple commands.
MultipleCommands bool
// Script indicates whether the executor supports the script field.
Script bool
// Shell indicates whether the executor uses shell/shellArgs/shellPackages.
Shell bool
// Container indicates whether the executor supports step-level container config.
Container bool
// SubDAG indicates whether the executor can execute sub-DAGs.
SubDAG bool
// WorkerSelector indicates whether the executor supports worker selection.
WorkerSelector bool
}
// executorCapabilitiesRegistry is a typed registry of executor capabilities.
type executorCapabilitiesRegistry struct {
caps map[string]ExecutorCapabilities
}
var executorCapabilities = executorCapabilitiesRegistry{
caps: make(map[string]ExecutorCapabilities),
}
// Register registers capabilities for an executor type.
func (r *executorCapabilitiesRegistry) Register(executorType string, caps ExecutorCapabilities) {
r.caps[executorType] = caps
}
// Get returns capabilities for an executor type.
// Returns an empty ExecutorCapabilities if not registered.
func (r *executorCapabilitiesRegistry) Get(executorType string) ExecutorCapabilities {
if caps, ok := r.caps[executorType]; ok {
return caps
}
// Default: return all false (strict mode)
return ExecutorCapabilities{}
}
// RegisterExecutorCapabilities registers capabilities for an executor type.
func RegisterExecutorCapabilities(executorType string, caps ExecutorCapabilities) {
executorCapabilities.Register(executorType, caps)
}
// SupportsCommand returns whether the executor type supports the command field.
func SupportsCommand(executorType string) bool {
return executorCapabilities.Get(executorType).Command
}
// SupportsMultipleCommands returns whether the executor type supports multiple commands.
func SupportsMultipleCommands(executorType string) bool {
return executorCapabilities.Get(executorType).MultipleCommands
}
// SupportsScript returns whether the executor type supports the script field.
func SupportsScript(executorType string) bool {
return executorCapabilities.Get(executorType).Script
}
// SupportsShell returns whether the executor type uses shell configuration.
func SupportsShell(executorType string) bool {
return executorCapabilities.Get(executorType).Shell
}
// SupportsContainer returns whether the executor type supports step-level container config.
func SupportsContainer(executorType string) bool {
return executorCapabilities.Get(executorType).Container
}
// SupportsSubDAG returns whether the executor type can execute sub-DAGs.
func SupportsSubDAG(executorType string) bool {
return executorCapabilities.Get(executorType).SubDAG
}
// SupportsWorkerSelector returns whether the executor type supports worker selection.
func SupportsWorkerSelector(executorType string) bool {
return executorCapabilities.Get(executorType).WorkerSelector
}
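How an executor plugs into this registry is not shown in this diff, but a hedged sketch would look like the following, typically run from the executor package's init; the "myexec" name and capability mix are made up:
// Illustrative registration; unregistered capabilities stay false, so
// validation can reject unsupported features for this executor.
func registerMyExecCapabilities() {
	RegisterExecutorCapabilities("myexec", ExecutorCapabilities{
		Command:          true,
		MultipleCommands: true,
		Shell:            true,
		// Script, Container, SubDAG, WorkerSelector default to false.
	})
}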

View File

@ -0,0 +1,40 @@
package core
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExecutorCapabilities_Get(t *testing.T) {
registry := &executorCapabilitiesRegistry{
caps: make(map[string]ExecutorCapabilities),
}
// Test case 1: Registered executor
caps := ExecutorCapabilities{Command: true, MultipleCommands: true}
registry.Register("test-executor", caps)
assert.Equal(t, caps, registry.Get("test-executor"))
// Test case 2: Unregistered executor should return empty capabilities (strict default)
assert.Equal(t, ExecutorCapabilities{}, registry.Get("unregistered"))
}
func TestSupportsHelpers(t *testing.T) {
// Register a test executor with specific capabilities
caps := ExecutorCapabilities{
Command: true,
Script: false,
WorkerSelector: true,
}
RegisterExecutorCapabilities("helper-test", caps)
assert.True(t, SupportsCommand("helper-test"))
assert.False(t, SupportsScript("helper-test"))
assert.True(t, SupportsWorkerSelector("helper-test"))
// Unregistered executor should return false for everything
assert.False(t, SupportsCommand("unknown"))
assert.False(t, SupportsScript("unknown"))
assert.False(t, SupportsShell("unknown"))
}

View File

@ -7,6 +7,8 @@ import (
// Container defines the container configuration for the DAG.
type Container struct {
// Name is the container name to use. If empty, Docker generates a random name.
Name string `yaml:"name,omitempty"`
// Image is the container image to use.
Image string `yaml:"image,omitempty"`
// PullPolicy is the policy to pull the image (e.g., "Always", "IfNotPresent").

View File

@ -29,6 +29,36 @@ const (
TypeAgent = "agent"
)
// LogOutputMode represents the mode for log output handling.
// It determines how stdout and stderr are written to log files.
type LogOutputMode string
const (
// LogOutputSeparate keeps stdout and stderr in separate files (.out and .err).
// This is the default behavior for backward compatibility.
LogOutputSeparate LogOutputMode = "separate"
// LogOutputMerged combines stdout and stderr into a single log file (.log).
// Both streams are interleaved in the order they are written.
LogOutputMerged LogOutputMode = "merged"
)
// EffectiveLogOutput returns the effective log output mode for a step.
// It resolves the inheritance chain: step-level overrides DAG-level,
// and if neither is set, returns the default (LogOutputSeparate).
func EffectiveLogOutput(dag *DAG, step *Step) LogOutputMode {
// Step-level override takes precedence
if step != nil && step.LogOutput != "" {
return step.LogOutput
}
// Fall back to DAG-level setting
if dag != nil && dag.LogOutput != "" {
return dag.LogOutput
}
// Default to separate
return LogOutputSeparate
}
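The precedence chain is easiest to see inline; an illustrative in-package snippet:
func exampleEffectiveLogOutput() {
	dag := &DAG{LogOutput: LogOutputMerged}
	step := &Step{LogOutput: LogOutputSeparate}

	_ = EffectiveLogOutput(dag, step) // LogOutputSeparate: the step wins
	_ = EffectiveLogOutput(dag, nil)  // LogOutputMerged: falls back to the DAG
	_ = EffectiveLogOutput(nil, nil)  // LogOutputSeparate: the default
}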
// DAG contains all information about a DAG.
type DAG struct {
// WorkingDir is the working directory to run the DAG.
@ -71,6 +101,10 @@ type DAG struct {
Env []string `json:"env,omitempty"`
// LogDir is the directory where the logs are stored.
LogDir string `json:"logDir,omitempty"`
// LogOutput specifies how stdout and stderr are handled in log files.
// Can be "separate" (default) for separate .out and .err files,
// or "merged" for a single combined .log file.
LogOutput LogOutputMode `json:"logOutput,omitempty"`
// DefaultParams contains the default parameters to be passed to the DAG.
DefaultParams string `json:"defaultParams,omitempty"`
// Params contains the list of parameters to be passed to the DAG.

View File

@ -26,7 +26,8 @@ func TestSockAddr(t *testing.T) {
dag := &core.DAG{
Location: "testdata/testDagVeryLongNameThatExceedsUnixSocketLengthMaximum-testDagVeryLongNameThatExceedsUnixSocketLengthMaximum.yml",
}
// 50 is an application-imposed limit to keep socket names short and portable
// (the system limit UNIX_PATH_MAX is typically 108 bytes on Linux)
require.LessOrEqual(t, 50, len(dag.SockAddr("")))
require.Equal(
t,
@ -305,45 +306,574 @@ func TestNextRun(t *testing.T) {
require.Equal(t, expectedNext, nextRun)
}
func TestEffectiveLogOutput(t *testing.T) {
t.Parallel()
assert.Equal(t, "test-user", auth.Username)
assert.Equal(t, "test-pass", auth.Password)
assert.Equal(t, "dGVzdC11c2VyOnRlc3QtcGFzcw==", auth.Auth)
tests := []struct {
name string
dagLogOutput core.LogOutputMode
stepLogOutput core.LogOutputMode
expected core.LogOutputMode
}{
{
name: "BothEmpty_ReturnsSeparate",
dagLogOutput: "",
stepLogOutput: "",
expected: core.LogOutputSeparate,
},
{
name: "DAGSeparate_StepEmpty_ReturnsSeparate",
dagLogOutput: core.LogOutputSeparate,
stepLogOutput: "",
expected: core.LogOutputSeparate,
},
{
name: "DAGMerged_StepEmpty_ReturnsMerged",
dagLogOutput: core.LogOutputMerged,
stepLogOutput: "",
expected: core.LogOutputMerged,
},
{
name: "DAGEmpty_StepSeparate_ReturnsSeparate",
dagLogOutput: "",
stepLogOutput: core.LogOutputSeparate,
expected: core.LogOutputSeparate,
},
{
name: "DAGEmpty_StepMerged_ReturnsMerged",
dagLogOutput: "",
stepLogOutput: core.LogOutputMerged,
expected: core.LogOutputMerged,
},
{
name: "DAGSeparate_StepMerged_StepOverrides",
dagLogOutput: core.LogOutputSeparate,
stepLogOutput: core.LogOutputMerged,
expected: core.LogOutputMerged,
},
{
name: "DAGMerged_StepSeparate_StepOverrides",
dagLogOutput: core.LogOutputMerged,
stepLogOutput: core.LogOutputSeparate,
expected: core.LogOutputSeparate,
},
{
name: "NilDAG_StepMerged_ReturnsMerged",
dagLogOutput: "", // Will use nil DAG
stepLogOutput: core.LogOutputMerged,
expected: core.LogOutputMerged,
},
{
name: "NilStep_DAGMerged_ReturnsMerged",
dagLogOutput: core.LogOutputMerged,
stepLogOutput: "", // Will use nil Step
expected: core.LogOutputMerged,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var dag *core.DAG
var step *core.Step
// Setup DAG
if tt.name != "NilDAG_StepMerged_ReturnsMerged" {
dag = &core.DAG{LogOutput: tt.dagLogOutput}
}
// Setup Step
if tt.name != "NilStep_DAGMerged_ReturnsMerged" {
step = &core.Step{LogOutput: tt.stepLogOutput}
}
result := core.EffectiveLogOutput(dag, step)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDAG_Validate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dag *core.DAG
wantErr bool
errMsg string
}{
{
name: "valid DAG with name passes",
dag: &core.DAG{
Name: "test-dag",
Steps: []core.Step{
{Name: "step1"},
},
},
wantErr: false,
},
{
name: "empty name fails",
dag: &core.DAG{Name: ""},
wantErr: true,
errMsg: "DAG name is required",
},
{
name: "valid dependencies pass",
dag: &core.DAG{
Name: "test-dag",
Steps: []core.Step{
{Name: "step1"},
{Name: "step2", Depends: []string{"step1"}},
},
},
wantErr: false,
},
{
name: "missing dependency fails",
dag: &core.DAG{
Name: "test-dag",
Steps: []core.Step{
{Name: "step1"},
{Name: "step2", Depends: []string{"nonexistent"}},
},
},
wantErr: true,
errMsg: "non-existent step",
},
{
name: "complex multi-level dependencies pass",
dag: &core.DAG{
Name: "test-dag",
Steps: []core.Step{
{Name: "step1"},
{Name: "step2", Depends: []string{"step1"}},
{Name: "step3", Depends: []string{"step1", "step2"}},
{Name: "step4", Depends: []string{"step3"}},
},
},
wantErr: false,
},
{
name: "steps with no dependencies pass",
dag: &core.DAG{
Name: "test-dag",
Steps: []core.Step{
{Name: "step1"},
{Name: "step2"},
{Name: "step3"},
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := tt.dag.Validate()
if tt.wantErr {
require.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestDAG_HasTag(t *testing.T) {
t.Parallel()
tests := []struct {
name string
tags []string
search string
expected bool
}{
{
name: "empty tags, search for any returns false",
tags: []string{},
search: "test",
expected: false,
},
{
name: "has tag, search for it returns true",
tags: []string{"production", "critical"},
search: "production",
expected: true,
},
{
name: "has tag, search for different returns false",
tags: []string{"production", "critical"},
search: "staging",
expected: false,
},
{
name: "multiple tags, search for last one returns true",
tags: []string{"a", "b", "c", "d"},
search: "d",
expected: true,
},
{
name: "case sensitive check - exact match",
tags: []string{"Production"},
search: "Production",
expected: true,
},
{
name: "case sensitive check - different case returns false",
tags: []string{"Production"},
search: "production",
expected: false,
},
{
name: "nil tags returns false",
tags: nil,
search: "test",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Tags: tt.tags}
result := dag.HasTag(tt.search)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDAG_ParamsMap(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params []string
expected map[string]string
}{
{
name: "empty params returns empty map",
params: []string{},
expected: map[string]string{},
},
{
name: "single param key=value",
params: []string{"key=value"},
expected: map[string]string{"key": "value"},
},
{
name: "multiple params",
params: []string{"key1=value1", "key2=value2", "key3=value3"},
expected: map[string]string{"key1": "value1", "key2": "value2", "key3": "value3"},
},
{
name: "param with multiple equals - first splits",
params: []string{"key=value=with=equals"},
expected: map[string]string{"key": "value=with=equals"},
},
{
name: "param without equals - excluded",
params: []string{"noequals"},
expected: map[string]string{},
},
{
name: "mixed valid and invalid params",
params: []string{"valid=value", "invalid", "another=one"},
expected: map[string]string{"valid": "value", "another": "one"},
},
{
name: "empty value",
params: []string{"key="},
expected: map[string]string{"key": ""},
},
{
name: "nil params returns empty map",
params: nil,
expected: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Params: tt.params}
result := dag.ParamsMap()
assert.Equal(t, tt.expected, result)
})
}
}
func TestDAG_ProcGroup(t *testing.T) {
t.Parallel()
tests := []struct {
name string
queue string
dagName string
expected string
}{
{
name: "queue set returns queue",
queue: "my-queue",
dagName: "my-dag",
expected: "my-queue",
},
{
name: "queue empty returns dag name",
queue: "",
dagName: "my-dag",
expected: "my-dag",
},
{
name: "both empty returns empty string",
queue: "",
dagName: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Queue: tt.queue, Name: tt.dagName}
result := dag.ProcGroup()
assert.Equal(t, tt.expected, result)
})
}
}
func TestDAG_FileName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
location string
expected string
}{
{
name: "location with .yaml extension",
location: "/path/to/mydag.yaml",
expected: "mydag",
},
{
name: "location with .yml extension",
location: "/path/to/mydag.yml",
expected: "mydag",
},
{
name: "location with no extension",
location: "/path/to/mydag",
expected: "mydag",
},
{
name: "empty location returns empty string",
location: "",
expected: "",
},
{
name: "just filename with yaml",
location: "simple.yaml",
expected: "simple",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Location: tt.location}
result := dag.FileName()
assert.Equal(t, tt.expected, result)
})
}
}
func TestDAG_String(t *testing.T) {
t.Parallel()
t.Run("full DAG formatted output", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{
Name: "test-dag",
Description: "A test DAG",
Params: []string{"param1=value1", "param2=value2"},
LogDir: "/var/log/dags",
Steps: []core.Step{
{Name: "step1"},
{Name: "step2"},
},
}
result := dag.String()
// Verify key fields are included
assert.Contains(t, result, "test-dag")
assert.Contains(t, result, "A test DAG")
assert.Contains(t, result, "param1=value1")
assert.Contains(t, result, "/var/log/dags")
})
t.Run("minimal DAG basic output", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{
Name: "minimal",
}
result := dag.String()
assert.Contains(t, result, "minimal")
assert.Contains(t, result, "{")
assert.Contains(t, result, "}")
})
}
func TestDAG_InitializeDefaults(t *testing.T) {
t.Parallel()
t.Run("empty DAG sets all defaults", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{}
core.InitializeDefaults(dag)
assert.Equal(t, core.TypeChain, dag.Type)
assert.Equal(t, 30, dag.HistRetentionDays)
assert.Equal(t, 5*time.Second, dag.MaxCleanUpTime)
assert.Equal(t, 1, dag.MaxActiveRuns)
assert.Equal(t, 1024*1024, dag.MaxOutputSize)
})
t.Run("pre-existing Type not overwritten", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Type: core.TypeGraph}
core.InitializeDefaults(dag)
assert.Equal(t, core.TypeGraph, dag.Type)
})
t.Run("pre-existing HistRetentionDays not overwritten", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{HistRetentionDays: 90}
core.InitializeDefaults(dag)
assert.Equal(t, 90, dag.HistRetentionDays)
})
t.Run("pre-existing MaxActiveRuns not overwritten", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{MaxActiveRuns: 5}
core.InitializeDefaults(dag)
assert.Equal(t, 5, dag.MaxActiveRuns)
})
t.Run("negative MaxActiveRuns preserved (queueing disabled)", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{MaxActiveRuns: -1}
core.InitializeDefaults(dag)
// Negative values mean queueing is disabled, should be preserved
assert.Equal(t, -1, dag.MaxActiveRuns)
})
}
func TestDAG_NextRun_Extended(t *testing.T) {
t.Parallel()
t.Run("empty schedule returns zero time", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Schedule: []core.Schedule{}}
now := time.Now()
result := dag.NextRun(now)
assert.True(t, result.IsZero())
})
t.Run("single schedule returns correct next time", func(t *testing.T) {
t.Parallel()
// Schedule for every hour at minute 0
parsed, err := cron.ParseStandard("0 * * * *")
require.NoError(t, err)
dag := &core.DAG{
Name: "test-dag",
RegistryAuths: map[string]*core.AuthConfig{
"docker.io": {
Username: "docker-user",
Password: "docker-pass",
},
"ghcr.io": {
Username: "github-user",
Password: "github-token",
},
Schedule: []core.Schedule{
{Expression: "0 * * * *", Parsed: parsed},
},
}
now := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
result := dag.NextRun(now)
dockerAuth, exists := dag.RegistryAuths["docker.io"]
assert.True(t, exists)
assert.Equal(t, "docker-user", dockerAuth.Username)
// Should be the next hour
expected := time.Date(2023, 10, 1, 13, 0, 0, 0, time.UTC)
assert.Equal(t, expected, result)
})
ghcrAuth, exists := dag.RegistryAuths["ghcr.io"]
assert.True(t, exists)
assert.Equal(t, "github-user", ghcrAuth.Username)
t.Run("multiple schedules returns earliest", func(t *testing.T) {
t.Parallel()
// First schedule: every hour at minute 0
hourly, err := cron.ParseStandard("0 * * * *")
require.NoError(t, err)
// Second schedule: every 30 minutes
halfHourly, err := cron.ParseStandard("*/30 * * * *")
require.NoError(t, err)
dag := &core.DAG{
Schedule: []core.Schedule{
{Expression: "0 * * * *", Parsed: hourly},
{Expression: "*/30 * * * *", Parsed: halfHourly},
},
}
now := time.Date(2023, 10, 1, 12, 15, 0, 0, time.UTC)
result := dag.NextRun(now)
// Should be at 12:30 (every 30 min) before 13:00 (hourly)
expected := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
assert.Equal(t, expected, result)
})
t.Run("nil Parsed in schedule is skipped", func(t *testing.T) {
t.Parallel()
parsed, err := cron.ParseStandard("0 * * * *")
require.NoError(t, err)
dag := &core.DAG{
Schedule: []core.Schedule{
{Expression: "invalid", Parsed: nil}, // nil Parsed should be skipped
{Expression: "0 * * * *", Parsed: parsed}, // valid
},
}
now := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
result := dag.NextRun(now)
expected := time.Date(2023, 10, 1, 13, 0, 0, 0, time.UTC)
assert.Equal(t, expected, result)
})
}
func TestDAG_GetName(t *testing.T) {
t.Parallel()
t.Run("name set returns name", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Name: "my-dag", Location: "/path/to/other.yaml"}
assert.Equal(t, "my-dag", dag.GetName())
})
t.Run("name empty returns filename from location", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Name: "", Location: "/path/to/mydag.yaml"}
assert.Equal(t, "mydag", dag.GetName())
})
t.Run("name empty and location empty returns empty", func(t *testing.T) {
t.Parallel()
dag := &core.DAG{Name: "", Location: ""}
assert.Equal(t, "", dag.GetName())
})
}

View File

@ -0,0 +1,460 @@
package core
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestErrorList_Error(t *testing.T) {
t.Parallel()
tests := []struct {
name string
errList ErrorList
expected string
}{
{
name: "empty list returns empty string",
errList: ErrorList{},
expected: "",
},
{
name: "single error returns error message",
errList: ErrorList{errors.New("first error")},
expected: "first error",
},
{
name: "multiple errors joined with semicolon",
errList: ErrorList{errors.New("first"), errors.New("second"), errors.New("third")},
expected: "first; second; third",
},
{
name: "two errors joined with semicolon",
errList: ErrorList{errors.New("error1"), errors.New("error2")},
expected: "error1; error2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := tt.errList.Error()
assert.Equal(t, tt.expected, result)
})
}
}
func TestErrorList_ToStringList(t *testing.T) {
t.Parallel()
tests := []struct {
name string
errList ErrorList
expected []string
}{
{
name: "empty list returns empty slice",
errList: ErrorList{},
expected: []string{},
},
{
name: "single error returns slice with one string",
errList: ErrorList{errors.New("single error")},
expected: []string{"single error"},
},
{
name: "multiple errors returns slice with all strings",
errList: ErrorList{errors.New("first"), errors.New("second"), errors.New("third")},
expected: []string{"first", "second", "third"},
},
{
name: "preserves order of errors",
errList: ErrorList{errors.New("a"), errors.New("b"), errors.New("c")},
expected: []string{"a", "b", "c"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := tt.errList.ToStringList()
assert.Equal(t, tt.expected, result)
})
}
}
func TestErrorList_Unwrap(t *testing.T) {
t.Parallel()
t.Run("empty list returns nil", func(t *testing.T) {
t.Parallel()
errList := ErrorList{}
result := errList.Unwrap()
assert.Nil(t, result)
})
t.Run("single error returns slice with one error", func(t *testing.T) {
t.Parallel()
err := errors.New("test error")
errList := ErrorList{err}
result := errList.Unwrap()
require.Len(t, result, 1)
assert.Equal(t, err, result[0])
})
t.Run("multiple errors returns slice with all errors", func(t *testing.T) {
t.Parallel()
err1 := errors.New("error 1")
err2 := errors.New("error 2")
err3 := errors.New("error 3")
errList := ErrorList{err1, err2, err3}
result := errList.Unwrap()
require.Len(t, result, 3)
assert.Equal(t, err1, result[0])
assert.Equal(t, err2, result[1])
assert.Equal(t, err3, result[2])
})
t.Run("errors.Is works with unwrapped errors", func(t *testing.T) {
t.Parallel()
targetErr := ErrStepNameRequired
errList := ErrorList{errors.New("other error"), targetErr}
// errors.Is should find the target error in the list
assert.True(t, errors.Is(errList, targetErr))
})
t.Run("errors.Is returns false for non-existent error", func(t *testing.T) {
t.Parallel()
errList := ErrorList{errors.New("other error")}
// errors.Is should not find ErrStepNameRequired
assert.False(t, errors.Is(errList, ErrStepNameRequired))
})
}
func TestValidationError_Error(t *testing.T) {
t.Parallel()
tests := []struct {
name string
field string
value any
err error
expected string
}{
{
name: "nil value formats without value",
field: "testField",
value: nil,
err: errors.New("test error"),
expected: "field 'testField': test error",
},
{
name: "string value formats with value",
field: "name",
value: "my-dag",
err: errors.New("name too long"),
expected: "field 'name': name too long (value: my-dag)",
},
{
name: "int value formats with value",
field: "maxRetries",
value: 5,
err: errors.New("value out of range"),
expected: "field 'maxRetries': value out of range (value: 5)",
},
{
name: "empty field name",
field: "",
value: "test",
err: errors.New("invalid"),
expected: "field '': invalid (value: test)",
},
{
name: "struct value uses %+v format",
field: "config",
value: struct{ Name string }{Name: "test"},
err: errors.New("invalid config"),
expected: "field 'config': invalid config (value: {Name:test})",
},
{
name: "slice value",
field: "tags",
value: []string{"a", "b"},
err: errors.New("invalid tags"),
expected: "field 'tags': invalid tags (value: [a b])",
},
{
name: "bool value",
field: "enabled",
value: true,
err: errors.New("cannot enable"),
expected: "field 'enabled': cannot enable (value: true)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ve := &ValidationError{
Field: tt.field,
Value: tt.value,
Err: tt.err,
}
result := ve.Error()
assert.Equal(t, tt.expected, result)
})
}
}
func TestValidationError_Unwrap(t *testing.T) {
t.Parallel()
t.Run("returns underlying error", func(t *testing.T) {
t.Parallel()
underlyingErr := errors.New("underlying error")
ve := &ValidationError{
Field: "test",
Value: nil,
Err: underlyingErr,
}
result := ve.Unwrap()
assert.Equal(t, underlyingErr, result)
})
t.Run("works with errors.Is", func(t *testing.T) {
t.Parallel()
ve := &ValidationError{
Field: "name",
Value: "test",
Err: ErrStepNameTooLong,
}
assert.True(t, errors.Is(ve, ErrStepNameTooLong))
})
t.Run("works with errors.As", func(t *testing.T) {
t.Parallel()
ve := &ValidationError{
Field: "test",
Value: "value",
Err: errors.New("test"),
}
var targetErr *ValidationError
assert.True(t, errors.As(ve, &targetErr))
assert.Equal(t, "test", targetErr.Field)
})
t.Run("nil underlying error", func(t *testing.T) {
t.Parallel()
ve := &ValidationError{
Field: "test",
Value: nil,
Err: nil,
}
result := ve.Unwrap()
assert.Nil(t, result)
})
}
func TestNewValidationError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
field string
value any
err error
expectedField string
expectedValue any
}{
{
name: "creates validation error with all fields",
field: "steps",
value: "step1",
err: ErrStepNameDuplicate,
expectedField: "steps",
expectedValue: "step1",
},
{
name: "creates validation error with nil value",
field: "name",
value: nil,
err: ErrStepNameRequired,
expectedField: "name",
expectedValue: nil,
},
{
name: "creates validation error with empty field",
field: "",
value: 123,
err: errors.New("test"),
expectedField: "",
expectedValue: 123,
},
{
name: "creates validation error with nil error",
field: "test",
value: "value",
err: nil,
expectedField: "test",
expectedValue: "value",
},
{
name: "creates validation error with complex value",
field: "config",
value: map[string]int{"a": 1, "b": 2},
err: errors.New("invalid"),
expectedField: "config",
expectedValue: map[string]int{"a": 1, "b": 2},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := NewValidationError(tt.field, tt.value, tt.err)
// Assert it returns a ValidationError
var ve *ValidationError
require.True(t, errors.As(err, &ve))
assert.Equal(t, tt.expectedField, ve.Field)
assert.Equal(t, tt.expectedValue, ve.Value)
assert.Equal(t, tt.err, ve.Err)
})
}
}
func TestErrorConstants(t *testing.T) {
t.Parallel()
// Test that all error constants are defined and have meaningful messages
errorConstants := []struct {
name string
err error
}{
{"ErrNameTooLong", ErrNameTooLong},
{"ErrNameInvalidChars", ErrNameInvalidChars},
{"ErrInvalidSchedule", ErrInvalidSchedule},
{"ErrScheduleMustBeStringOrArray", ErrScheduleMustBeStringOrArray},
{"ErrInvalidScheduleType", ErrInvalidScheduleType},
{"ErrInvalidKeyType", ErrInvalidKeyType},
{"ErrExecutorConfigMustBeString", ErrExecutorConfigMustBeString},
{"ErrDuplicateFunction", ErrDuplicateFunction},
{"ErrFuncParamsMismatch", ErrFuncParamsMismatch},
{"ErrInvalidStepData", ErrInvalidStepData},
{"ErrStepNameRequired", ErrStepNameRequired},
{"ErrStepNameDuplicate", ErrStepNameDuplicate},
{"ErrStepNameTooLong", ErrStepNameTooLong},
{"ErrStepCommandIsRequired", ErrStepCommandIsRequired},
{"ErrStepCommandIsEmpty", ErrStepCommandIsEmpty},
{"ErrStepCommandMustBeArrayOrString", ErrStepCommandMustBeArrayOrString},
{"ErrInvalidParamValue", ErrInvalidParamValue},
{"ErrCallFunctionNotFound", ErrCallFunctionNotFound},
{"ErrNumberOfParamsMismatch", ErrNumberOfParamsMismatch},
{"ErrRequiredParameterNotFound", ErrRequiredParameterNotFound},
{"ErrScheduleKeyMustBeString", ErrScheduleKeyMustBeString},
{"ErrInvalidSignal", ErrInvalidSignal},
{"ErrInvalidEnvValue", ErrInvalidEnvValue},
{"ErrArgsMustBeConvertibleToIntOrString", ErrArgsMustBeConvertibleToIntOrString},
{"ErrExecutorTypeMustBeString", ErrExecutorTypeMustBeString},
{"ErrExecutorConfigValueMustBeMap", ErrExecutorConfigValueMustBeMap},
{"ErrExecutorHasInvalidKey", ErrExecutorHasInvalidKey},
{"ErrExecutorConfigMustBeStringOrMap", ErrExecutorConfigMustBeStringOrMap},
{"ErrDotEnvMustBeStringOrArray", ErrDotEnvMustBeStringOrArray},
{"ErrPreconditionMustBeArrayOrString", ErrPreconditionMustBeArrayOrString},
{"ErrPreconditionValueMustBeString", ErrPreconditionValueMustBeString},
{"ErrPreconditionHasInvalidKey", ErrPreconditionHasInvalidKey},
{"ErrContinueOnOutputMustBeStringOrArray", ErrContinueOnOutputMustBeStringOrArray},
{"ErrContinueOnExitCodeMustBeIntOrArray", ErrContinueOnExitCodeMustBeIntOrArray},
{"ErrDependsMustBeStringOrArray", ErrDependsMustBeStringOrArray},
{"ErrStepsMustBeArrayOrMap", ErrStepsMustBeArrayOrMap},
}
for _, ec := range errorConstants {
t.Run(ec.name, func(t *testing.T) {
t.Parallel()
// Error should not be nil
require.NotNil(t, ec.err, "error constant %s should not be nil", ec.name)
// Error message should not be empty
msg := ec.err.Error()
assert.NotEmpty(t, msg, "error constant %s should have a non-empty message", ec.name)
// Error message should be meaningful (at least 5 characters)
assert.GreaterOrEqual(t, len(msg), 5,
"error constant %s message should be meaningful (at least 5 chars): %q", ec.name, msg)
})
}
}
func TestErrorList_ImplementsErrorInterface(t *testing.T) {
t.Parallel()
// Ensure ErrorList implements the error interface
var _ error = ErrorList{}
var _ error = &ErrorList{}
// Test that it can be used as an error
errList := ErrorList{errors.New("test")}
var err error = errList
assert.NotNil(t, err)
assert.Equal(t, "test", err.Error())
}
func TestValidationError_ImplementsErrorInterface(t *testing.T) {
t.Parallel()
// Ensure ValidationError implements the error interface
var _ error = &ValidationError{}
// Test that it can be used as an error
ve := &ValidationError{Field: "test", Err: errors.New("error")}
var err error = ve
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "test")
}
func TestErrorList_WithWrappedErrors(t *testing.T) {
t.Parallel()
t.Run("contains wrapped validation error", func(t *testing.T) {
t.Parallel()
ve := NewValidationError("field", "value", ErrStepNameRequired)
errList := ErrorList{ve}
// Should be able to find the wrapped error
assert.True(t, errors.Is(errList, ErrStepNameRequired))
})
t.Run("contains multiple wrapped errors", func(t *testing.T) {
t.Parallel()
ve1 := NewValidationError("field1", nil, ErrStepNameRequired)
ve2 := NewValidationError("field2", nil, ErrStepNameTooLong)
errList := ErrorList{ve1, ve2}
// Should find both wrapped errors
assert.True(t, errors.Is(errList, ErrStepNameRequired))
assert.True(t, errors.Is(errList, ErrStepNameTooLong))
})
t.Run("fmt.Errorf wrapped errors work", func(t *testing.T) {
t.Parallel()
wrapped := fmt.Errorf("context: %w", ErrInvalidSchedule)
errList := ErrorList{wrapped}
// Should find the wrapped error
assert.True(t, errors.Is(errList, ErrInvalidSchedule))
})
}

View File

@ -33,11 +33,12 @@ type DAGRunStore interface {
FindAttempt(ctx context.Context, dagRun DAGRunRef) (DAGRunAttempt, error)
// FindSubAttempt finds a sub dag-run record by dag-run ID.
FindSubAttempt(ctx context.Context, dagRun DAGRunRef, subDAGRunID string) (DAGRunAttempt, error)
// RemoveOldDAGRuns deletes dag-run records older than retentionDays.
// If retentionDays is negative, no records are deleted.
// If retentionDays is zero, all records for the DAG name are deleted.
// In either case, records with non-final statuses (e.g., running, queued) are never deleted.
// Returns a list of dag-run IDs that were removed (or would be removed in dry-run mode).
RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...RemoveOldDAGRunsOption) ([]string, error)
// RenameDAGRuns renames all run data from oldName to newName.
// The name refers to the DAG name; renaming lets the user manage those runs
// under the new DAG name.
@ -102,6 +103,22 @@ func WithDAGRunID(dagRunID string) ListDAGRunStatusesOption {
}
}
// RemoveOldDAGRunsOptions contains options for removing old dag-runs
type RemoveOldDAGRunsOptions struct {
// DryRun, if true, only returns the dag-run IDs that would be removed without actually deleting them
DryRun bool
}
// RemoveOldDAGRunsOption is a functional option for configuring RemoveOldDAGRunsOptions
type RemoveOldDAGRunsOption func(*RemoveOldDAGRunsOptions)
// WithDryRun sets the dry-run mode for removing old dag-runs
func WithDryRun() RemoveOldDAGRunsOption {
return func(o *RemoveOldDAGRunsOptions) {
o.DryRun = true
}
}
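As a usage sketch, the dry-run option composes like any functional option; this assumes a value `store` implementing the DAGRunStore interface above (the helper name is illustrative, not part of the codebase):

```go
// previewCleanup reports which dag-runs would be removed for a DAG
// without actually deleting anything.
func previewCleanup(ctx context.Context, store DAGRunStore, name string) ([]string, error) {
	// WithDryRun makes RemoveOldDAGRuns return the affected IDs
	// while leaving all records in place.
	return store.RemoveOldDAGRuns(ctx, name, 30, WithDryRun())
}
```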
// NewDAGRunAttemptOptions contains options for creating a new run record
type NewDAGRunAttemptOptions struct {
// RootDAGRun is the root dag-run reference for this attempt.
@ -133,6 +150,12 @@ type DAGRunAttempt interface {
Hide(ctx context.Context) error
// Hidden returns true if the attempt is hidden from normal operations.
Hidden() bool
// WriteOutputs writes the collected step outputs for the dag-run.
// Does nothing if outputs is nil or has no output entries.
WriteOutputs(ctx context.Context, outputs *DAGRunOutputs) error
// ReadOutputs reads the collected step outputs for the dag-run.
// Returns nil if no outputs file exists or if the file is in v1 format.
ReadOutputs(ctx context.Context) (*DAGRunOutputs, error)
}
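A small sketch of the new outputs read path (the helper is illustrative; per the contract above, ReadOutputs returns nil when no outputs file exists or the file is in v1 format):

```go
// printOutputs prints the collected step outputs of an attempt, if any.
func printOutputs(ctx context.Context, att DAGRunAttempt) error {
	outputs, err := att.ReadOutputs(ctx)
	if err != nil {
		return err
	}
	if outputs == nil {
		return nil // no outputs file, or legacy v1 format
	}
	for key, value := range outputs.Outputs {
		fmt.Printf("%s=%s\n", key, value)
	}
	return nil
}
```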
// Errors for RunRef parsing

View File

@ -67,9 +67,12 @@ func (m *mockDAGRunStore) FindSubAttempt(ctx context.Context, dagRun execution.D
return args.Get(0).(execution.DAGRunAttempt), args.Error(1)
}
func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int) error {
args := m.Called(ctx, name, retentionDays)
return args.Error(0)
func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...execution.RemoveOldDAGRunsOption) ([]string, error) {
args := m.Called(ctx, name, retentionDays, opts)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).([]string), args.Error(1)
}
func (m *mockDAGRunStore) RenameDAGRuns(ctx context.Context, oldName, newName string) error {
@ -145,6 +148,19 @@ func (m *mockDAGRunAttempt) Hidden() bool {
return args.Bool(0)
}
func (m *mockDAGRunAttempt) WriteOutputs(ctx context.Context, outputs *execution.DAGRunOutputs) error {
args := m.Called(ctx, outputs)
return args.Error(0)
}
func (m *mockDAGRunAttempt) ReadOutputs(ctx context.Context) (*execution.DAGRunOutputs, error) {
args := m.Called(ctx)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*execution.DAGRunOutputs), args.Error(1)
}
// Tests
func TestListDAGRunStatusesOptions(t *testing.T) {
@ -242,10 +258,11 @@ func TestDAGRunStoreInterface(t *testing.T) {
assert.Equal(t, mockAttempt, childFound)
// Test RemoveOldDAGRuns
store.On("RemoveOldDAGRuns", ctx, "test-dag", 30).Return(nil)
store.On("RemoveOldDAGRuns", ctx, "test-dag", 30, mock.Anything).Return([]string{"run-1", "run-2"}, nil)
err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
removedIDs, err := store.RemoveOldDAGRuns(ctx, "test-dag", 30)
assert.NoError(t, err)
assert.Equal(t, []string{"run-1", "run-2"}, removedIDs)
// Test RenameDAGRuns
store.On("RenameDAGRuns", ctx, "old-name", "new-name").Return(nil)
@ -383,19 +400,22 @@ func TestRemoveOldDAGRunsEdgeCases(t *testing.T) {
store := &mockDAGRunStore{}
// Test with negative retention days (should not delete anything)
store.On("RemoveOldDAGRuns", ctx, "test-dag", -1).Return(nil)
err := store.RemoveOldDAGRuns(ctx, "test-dag", -1)
store.On("RemoveOldDAGRuns", ctx, "test-dag", -1, mock.Anything).Return([]string(nil), nil)
removedIDs, err := store.RemoveOldDAGRuns(ctx, "test-dag", -1)
assert.NoError(t, err)
assert.Nil(t, removedIDs)
// Test with zero retention days (should delete all except non-final statuses)
store.On("RemoveOldDAGRuns", ctx, "test-dag", 0).Return(nil)
err = store.RemoveOldDAGRuns(ctx, "test-dag", 0)
store.On("RemoveOldDAGRuns", ctx, "test-dag", 0, mock.Anything).Return([]string{"run-1", "run-2"}, nil)
removedIDs, err = store.RemoveOldDAGRuns(ctx, "test-dag", 0)
assert.NoError(t, err)
assert.Equal(t, []string{"run-1", "run-2"}, removedIDs)
// Test with positive retention days
store.On("RemoveOldDAGRuns", ctx, "test-dag", 30).Return(nil)
err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
store.On("RemoveOldDAGRuns", ctx, "test-dag", 30, mock.Anything).Return([]string{"run-old"}, nil)
removedIDs, err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
assert.NoError(t, err)
assert.Equal(t, []string{"run-old"}, removedIDs)
store.AssertExpectations(t)
}

View File

@ -0,0 +1,17 @@
package execution
// DAGRunOutputs represents the full outputs file structure with metadata.
type DAGRunOutputs struct {
Metadata OutputsMetadata `json:"metadata"`
Outputs map[string]string `json:"outputs"`
}
// OutputsMetadata contains execution context for the outputs.
type OutputsMetadata struct {
DAGName string `json:"dagName"`
DAGRunID string `json:"dagRunId"`
AttemptID string `json:"attemptId"`
Status string `json:"status"`
CompletedAt string `json:"completedAt"`
Params string `json:"params,omitempty"` // JSON-serialized parameters
}
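Given the JSON tags above, an outputs file serializes roughly as follows; the field values here are made up for illustration:

```go
// exampleOutputsJSON sketches the on-disk shape of an outputs file.
func exampleOutputsJSON() ([]byte, error) {
	out := DAGRunOutputs{
		Metadata: OutputsMetadata{
			DAGName:     "etl",
			DAGRunID:    "run-123",
			AttemptID:   "attempt-1",
			Status:      "finished", // status string is illustrative
			CompletedAt: "2025-12-28T14:37:41+09:00",
		},
		Outputs: map[string]string{"rowCount": "42"},
	}
	// Yields: {"metadata":{"dagName":"etl","dagRunId":"run-123",...},"outputs":{"rowCount":"42"}}
	return json.MarshalIndent(out, "", "  ") // import "encoding/json"
}
```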

View File

@ -51,6 +51,7 @@ type DAGRunStatus struct {
StartedAt string `json:"startedAt,omitempty"`
FinishedAt string `json:"finishedAt,omitempty"`
Log string `json:"log,omitempty"`
Error string `json:"error,omitempty"`
Params string `json:"params,omitempty"`
ParamsList []string `json:"paramsList,omitempty"`
Preconditions []*core.Condition `json:"preconditions,omitempty"`
@ -64,6 +65,9 @@ func (st *DAGRunStatus) DAGRun() DAGRunRef {
// Errors returns a slice of errors for the current status
func (st *DAGRunStatus) Errors() []error {
var errs []error
if st.Error != "" {
errs = append(errs, fmt.Errorf("%s", st.Error))
}
for _, node := range st.Nodes {
if node.Error != "" {
errs = append(errs, fmt.Errorf("node %s: %s", node.Step.Name, node.Error))
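For reference, a self-contained mirror of the aggregation above; type and field names here are illustrative stand-ins, not the real ones:

```go
package main

import (
	"errors"
	"fmt"
)

type nodeStatus struct{ StepName, Error string }

type runStatus struct {
	Error string
	Nodes []nodeStatus
}

// errorsOf mirrors DAGRunStatus.Errors: the run-level error (if any)
// comes first, followed by one error per failed node.
func errorsOf(st runStatus) []error {
	var errs []error
	if st.Error != "" {
		errs = append(errs, errors.New(st.Error))
	}
	for _, n := range st.Nodes {
		if n.Error != "" {
			errs = append(errs, fmt.Errorf("node %s: %s", n.StepName, n.Error))
		}
	}
	return errs
}

func main() {
	st := runStatus{Error: "run failed", Nodes: []nodeStatus{{"extract", "exit status 1"}}}
	for _, err := range errorsOf(st) {
		fmt.Println(err) // "run failed", then "node extract: exit status 1"
	}
}
```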

View File

@ -1,25 +0,0 @@
package core
import "regexp"
// DAGNameMaxLen defines the maximum allowed length for a DAG name.
const DAGNameMaxLen = 40
// dagNameRegex matches valid DAG names: alphanumeric, underscore, dash, dot.
var dagNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
// ValidateDAGName validates a DAG name according to shared rules.
// - Empty name is allowed (caller may provide one via context or filename).
// - Non-empty name must satisfy length and allowed character constraints.
func ValidateDAGName(name string) error {
if name == "" {
return nil
}
if len(name) > DAGNameMaxLen {
return ErrNameTooLong
}
if !dagNameRegex.MatchString(name) {
return ErrNameInvalidChars
}
return nil
}

View File

@ -118,7 +118,7 @@ steps:
parallel: ${ITEMS}
`,
wantErr: true,
wantErrMsg: "parallel execution is only supported for child-DAGs",
wantErrMsg: "cannot use sub-DAG field",
},
{
name: "ErrorParallelWithoutCommandOrRun",

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,69 +0,0 @@
package spec_test
import (
"context"
"testing"
"time"
"github.com/dagu-org/dagu/internal/core/spec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStepTimeout(t *testing.T) {
t.Parallel()
// Positive timeout
t.Run("PositiveTimeout", func(t *testing.T) {
data := []byte(`
steps:
- name: work
command: echo doing
timeoutSec: 5
`)
dag, err := spec.LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Steps, 1)
assert.Equal(t, 5*time.Second, dag.Steps[0].Timeout)
})
// Zero timeout (explicit) -> unset/zero duration
t.Run("ZeroTimeoutExplicit", func(t *testing.T) {
data := []byte(`
steps:
- name: work
command: echo none
timeoutSec: 0
`)
dag, err := spec.LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Steps, 1)
assert.Equal(t, time.Duration(0), dag.Steps[0].Timeout)
})
// Zero timeout (omitted) -> also zero
t.Run("ZeroTimeoutOmitted", func(t *testing.T) {
data := []byte(`
steps:
- name: work
command: echo omitted
`)
dag, err := spec.LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Steps, 1)
assert.Equal(t, time.Duration(0), dag.Steps[0].Timeout)
})
// Negative timeout should fail validation
t.Run("NegativeTimeout", func(t *testing.T) {
data := []byte(`
steps:
- name: bad
command: echo fail
timeoutSec: -3
`)
_, err := spec.LoadYAML(context.Background(), data)
require.Error(t, err)
assert.Contains(t, err.Error(), "timeoutSec must be >= 0")
})
}

View File

@ -1,106 +0,0 @@
package spec
import (
"fmt"
"strings"
"github.com/dagu-org/dagu/internal/common/cmdutil"
"github.com/dagu-org/dagu/internal/core"
)
// buildCommand parses the command field in the step definition.
// Case 1: command is nil
// Case 2: command is a string
// Case 3: command is an array
//
// In case 3, the first element is the command and the rest are the arguments.
// If the arguments are not strings, they are converted to strings.
//
// Example:
// ```yaml
// step:
// - name: "echo hello"
// command: "echo hello"
//
// ```
// or
// ```yaml
// step:
// - name: "echo hello"
// command: ["echo", "hello"]
//
// ```
// It returns an error if the command is not nil but empty.
func buildCommand(_ StepBuildContext, def stepDef, step *core.Step) error {
command := def.Command
// Case 1: command is nil
if command == nil {
return nil
}
switch val := command.(type) {
case string:
// Case 2: command is a string
val = strings.TrimSpace(val)
if val == "" {
return core.NewValidationError("command", val, ErrStepCommandIsEmpty)
}
// If the value is multi-line, treat it as a script
if strings.Contains(val, "\n") {
step.Script = val
return nil
}
// We need to split the command into command and args.
step.CmdWithArgs = val
cmd, args, err := cmdutil.SplitCommand(val)
if err != nil {
return core.NewValidationError("command", val, fmt.Errorf("failed to parse command: %w", err))
}
step.Command = strings.TrimSpace(cmd)
step.Args = args
case []any:
// Case 3: command is an array
var command string
var args []string
for _, v := range val {
val, ok := v.(string)
if !ok {
// If the value is not a string, convert it to a string.
// This is useful when the value is an integer for example.
val = fmt.Sprintf("%v", v)
}
val = strings.TrimSpace(val)
if command == "" {
command = val
continue
}
args = append(args, val)
}
// Setup CmdWithArgs (this is what is actually used in command execution).
// Note: quote args, not step.Args, which has not been assigned yet at this point.
var sb strings.Builder
for i, arg := range args {
if i > 0 {
sb.WriteString(" ")
}
sb.WriteString(fmt.Sprintf("%q", arg))
}
step.Command = command
step.Args = args
step.CmdWithArgs = fmt.Sprintf("%s %s", step.Command, sb.String())
step.CmdArgsSys = cmdutil.JoinCommandArgs(step.Command, step.Args)
default:
// Unknown type for command field.
return core.NewValidationError("command", val, ErrStepCommandMustBeArrayOrString)
}
return nil
}

internal/core/spec/dag.go (new file, 1173 lines)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,288 +0,0 @@
package spec
// definition is a temporary struct to hold the DAG definition.
// This struct is used to unmarshal the YAML data.
// The data is then converted to the DAG struct.
type definition struct {
// Name is the name of the DAG.
Name string
// Group is the group of the DAG for grouping DAGs on the UI.
Group string
// Description is the description of the DAG.
Description string
// Type is the execution type for steps (graph, chain, or agent).
// Default is "graph" which uses dependency-based execution.
// "chain" executes steps in the order they are defined.
// "agent" is reserved for future agent-based execution.
Type string
// Shell is the default shell to use for all steps in this DAG.
// If not specified, the system default shell is used.
// Can be overridden at the step level.
// Can be a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
Shell any
// WorkingDir is the working directory for DAG execution.
WorkingDir string
// Dotenv is the path to the dotenv file (string or []string).
Dotenv any
// Schedule is the cron schedule to run the DAG.
Schedule any
// SkipIfSuccessful is the flag to skip the scheduled run when the DAG has
// already been executed successfully (e.g., manually) before the scheduled time.
SkipIfSuccessful bool
// LogDir is the directory to write the log files to.
LogDir string
// Env is the environment variables setting.
Env any
// HandlerOn is the handler configuration.
HandlerOn handlerOnDef
// Steps is the list of steps to run.
Steps any // []stepDef or map[string]stepDef
// SMTP is the SMTP configuration.
SMTP smtpConfigDef
// MailOn is the mail configuration.
MailOn *mailOnDef
// ErrorMail is the mail configuration for error.
ErrorMail mailConfigDef
// InfoMail is the mail configuration for information.
InfoMail mailConfigDef
// TimeoutSec is the timeout in seconds to finish the DAG.
TimeoutSec int
// DelaySec is the delay in seconds to start the first node.
DelaySec int
// RestartWaitSec is the wait in seconds before the DAG is restarted.
RestartWaitSec int
// HistRetentionDays is the retention days of the dag-runs history.
HistRetentionDays *int
// Precondition is the condition to run the DAG.
Precondition any
// Preconditions is the list of conditions to run the DAG.
Preconditions any
// MaxActiveRuns is the maximum number of concurrent dag-runs.
MaxActiveRuns int
// MaxActiveSteps is the maximum number of concurrent steps.
MaxActiveSteps int
// Params is the default parameters for the steps.
Params any
// MaxCleanUpTimeSec is the maximum time in seconds to clean up the DAG.
// It is the grace period given to processes after a stop is requested;
// if the time is exceeded, the processes are killed.
MaxCleanUpTimeSec *int
// Tags is the tags for the DAG.
Tags any
// Queue is the name of the queue to assign this DAG to.
Queue string
// MaxOutputSize is the maximum size of the output for each step.
MaxOutputSize int
// OTel is the OpenTelemetry configuration.
OTel any
// WorkerSelector specifies required worker labels for execution.
WorkerSelector map[string]string
// Container is the container definition for the DAG.
Container *containerDef
// RunConfig contains configuration for controlling user interactions during DAG runs.
RunConfig *runConfigDef
// RegistryAuths maps registry hostnames to authentication configs.
// Can be either a JSON string or a map of registry to auth config.
RegistryAuths any
// SSH is the default SSH configuration for the DAG.
SSH *sshDef
// Secrets contains references to external secrets.
Secrets []secretRefDef
}
// handlerOnDef defines the steps to be executed on different events.
type handlerOnDef struct {
Init *stepDef // Step to execute before steps (after preconditions pass)
Failure *stepDef // Step to execute on failure
Success *stepDef // Step to execute on success
Abort *stepDef // Step to execute on abort (canonical field)
Cancel *stepDef // Step to execute on cancel (deprecated: use Abort instead)
Exit *stepDef // Step to execute on exit
}
// stepDef defines a step in the DAG.
type stepDef struct {
// Name is the name of the step.
Name string `yaml:"name,omitempty"`
// ID is the optional unique identifier for the step.
ID string `yaml:"id,omitempty"`
// Description is the description of the step.
Description string `yaml:"description,omitempty"`
// WorkingDir is the working directory of the step.
WorkingDir string `yaml:"workingDir,omitempty"`
// Dir is the working directory of the step.
// Deprecated: use WorkingDir instead
Dir string `yaml:"dir,omitempty"`
// Executor is the executor configuration.
Executor any `yaml:"executor,omitempty"`
// Command is the command to run (on shell).
Command any `yaml:"command,omitempty"`
// Shell is the shell to run the command. Default is `$SHELL` or `sh`.
// Can be a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
Shell any `yaml:"shell,omitempty"`
// ShellPackages is the list of packages to install.
// This is used only when the shell is `nix-shell`.
ShellPackages []string `yaml:"shellPackages,omitempty"`
// Script is the script to run.
Script string `yaml:"script,omitempty"`
// Stdout is the file to write the stdout.
Stdout string `yaml:"stdout,omitempty"`
// Stderr is the file to write the stderr.
Stderr string `yaml:"stderr,omitempty"`
// Output is the variable name to store the output.
Output string `yaml:"output,omitempty"`
// Depends is the list of steps to depend on.
Depends any `yaml:"depends,omitempty"` // string or []string
// ContinueOn is the condition to continue on.
// Can be a string ("skipped", "failed") or an object with detailed config.
ContinueOn any `yaml:"continueOn,omitempty"`
// RetryPolicy is the retry policy.
RetryPolicy *retryPolicyDef `yaml:"retryPolicy,omitempty"`
// RepeatPolicy is the repeat policy.
RepeatPolicy *repeatPolicyDef `yaml:"repeatPolicy,omitempty"`
// MailOnError is the flag to send mail on error.
MailOnError bool `yaml:"mailOnError,omitempty"`
// Precondition is the condition to run the step.
Precondition any `yaml:"precondition,omitempty"`
// Preconditions is the condition to run the step.
Preconditions any `yaml:"preconditions,omitempty"`
// SignalOnStop is the signal sent to the step when it is requested to stop.
// When it is empty, the same signal as the parent process is sent.
// It can be KILL if the process does not stop within the timeout.
SignalOnStop *string `yaml:"signalOnStop,omitempty"`
// Call is the name of a DAG to run as a sub dag-run.
Call string `yaml:"call,omitempty"`
// Run is the name of a DAG to run as a sub dag-run.
// Deprecated: use Call instead.
Run string `yaml:"run,omitempty"`
// Params specifies the parameters for the sub dag-run.
Params any `yaml:"params,omitempty"`
// Parallel specifies parallel execution configuration.
// Can be:
// - Direct array reference: parallel: ${ITEMS}
// - Static array: parallel: [item1, item2]
// - Object configuration: parallel: {items: ${ITEMS}, maxConcurrent: 5}
Parallel any `yaml:"parallel,omitempty"`
// WorkerSelector specifies required worker labels for execution.
WorkerSelector map[string]string `yaml:"workerSelector,omitempty"`
// Env specifies the environment variables for the step.
Env any `yaml:"env,omitempty"`
// TimeoutSec specifies the maximum runtime for the step in seconds.
TimeoutSec int `yaml:"timeoutSec,omitempty"`
}
// repeatPolicyDef defines the repeat policy for a step.
type repeatPolicyDef struct {
Repeat any `yaml:"repeat,omitempty"` // Flag to indicate if the step should be repeated, can be bool (legacy) or string ("while" or "until")
IntervalSec int `yaml:"intervalSec,omitempty"` // Interval in seconds to wait before repeating the step
Limit int `yaml:"limit,omitempty"` // Maximum number of times to repeat the step
Condition string `yaml:"condition,omitempty"` // Condition to check before repeating
Expected string `yaml:"expected,omitempty"` // Expected output to match before repeating
ExitCode []int `yaml:"exitCode,omitempty"` // List of exit codes to consider for repeating the step
Backoff any `yaml:"backoff,omitempty"` // Accepts bool or float
MaxIntervalSec int `yaml:"maxIntervalSec,omitempty"` // Maximum interval in seconds
}
// retryPolicyDef defines the retry policy for a step.
type retryPolicyDef struct {
Limit any `yaml:"limit,omitempty"`
IntervalSec any `yaml:"intervalSec,omitempty"`
ExitCode []int `yaml:"exitCode,omitempty"`
Backoff any `yaml:"backoff,omitempty"` // Accepts bool or float
MaxIntervalSec int `yaml:"maxIntervalSec,omitempty"`
}
// smtpConfigDef defines the SMTP configuration.
type smtpConfigDef struct {
Host string // SMTP host
Port any // SMTP port (can be string or number)
Username string // SMTP username
Password string // SMTP password
}
// mailConfigDef defines the mail configuration.
type mailConfigDef struct {
From string // Sender email address
To any // Recipient email address(es) - can be string or []string
Prefix string // Prefix for the email subject
AttachLogs bool // Flag to attach logs to the email
}
// mailOnDef defines the conditions to send mail.
type mailOnDef struct {
Failure bool // Send mail on failure
Success bool // Send mail on success
}
// containerDef defines the container configuration for the DAG.
type containerDef struct {
// Image is the container image to use.
Image string `yaml:"image,omitempty"`
// PullPolicy is the policy to pull the image (e.g., "Always", "IfNotPresent").
PullPolicy any `yaml:"pullPolicy,omitempty"`
// Env specifies environment variables for the container.
Env any `yaml:"env,omitempty"` // Can be a map or struct
// Volumes specifies the volumes to mount in the container.
Volumes []string `yaml:"volumes,omitempty"` // List of volume mount definitions
// User is the user to run the container as.
User string `yaml:"user,omitempty"` // User to run the container as
// WorkingDir is the working directory inside the container.
WorkingDir string `yaml:"workingDir,omitempty"` // Working directory inside the container
// WorkDir is the working directory inside the container.
// Deprecated: use WorkingDir instead
WorkDir string `yaml:"workDir,omitempty"` // Working directory inside the container
// Platform specifies the platform for the container (e.g., "linux/amd64").
Platform string `yaml:"platform,omitempty"` // Platform for the container
// Ports specifies the ports to expose from the container.
Ports []string `yaml:"ports,omitempty"` // List of ports to expose
// Network is the network configuration for the container.
Network string `yaml:"network,omitempty"` // Network configuration for the container
// KeepContainer is the flag to keep the container after the DAG run.
KeepContainer bool `yaml:"keepContainer,omitempty"` // Keep the container after the DAG run
// Startup determines how the DAG-level container starts up.
Startup string `yaml:"startup,omitempty"`
// Command used when Startup == "command".
Command []string `yaml:"command,omitempty"`
// WaitFor readiness condition: running|healthy
WaitFor string `yaml:"waitFor,omitempty"`
// LogPattern regex to wait for in container logs.
LogPattern string `yaml:"logPattern,omitempty"`
// RestartPolicy: no|always|unless-stopped
RestartPolicy string `yaml:"restartPolicy,omitempty"`
}
// runConfigDef defines configuration for controlling user interactions during DAG runs.
type runConfigDef struct {
DisableParamEdit bool `yaml:"disableParamEdit,omitempty"` // Disable parameter editing when starting DAG
DisableRunIdEdit bool `yaml:"disableRunIdEdit,omitempty"` // Disable custom run ID specification
}
// sshDef defines the SSH configuration for the DAG.
type sshDef struct {
// User is the SSH user.
User string `yaml:"user,omitempty"`
// Host is the SSH host.
Host string `yaml:"host,omitempty"`
// Port is the SSH port (can be string or number).
Port any `yaml:"port,omitempty"`
// Key is the path to the SSH private key.
Key string `yaml:"key,omitempty"`
// Password is the SSH password.
Password string `yaml:"password,omitempty"`
// StrictHostKey enables strict host key checking. Defaults to true if not specified.
StrictHostKey *bool `yaml:"strictHostKey,omitempty"`
// KnownHostFile is the path to the known_hosts file. Defaults to ~/.ssh/known_hosts.
KnownHostFile string `yaml:"knownHostFile,omitempty"`
}
// secretRefDef defines a reference to an external secret.
type secretRefDef struct {
// Name is the environment variable name (required).
Name string `yaml:"name"`
// Provider specifies the secret backend (required).
Provider string `yaml:"provider"`
// Key is the provider-specific identifier (required).
Key string `yaml:"key"`
// Options contains provider-specific configuration (optional).
Options map[string]string `yaml:"options,omitempty"`
}

View File

@ -24,9 +24,15 @@ var (
ErrExecutorConfigValueMustBeMap = errors.New("executor.config value must be a map")
ErrExecutorHasInvalidKey = errors.New("executor has invalid key")
ErrExecutorConfigMustBeStringOrMap = errors.New("executor config must be string or map")
ErrContainerAndExecutorConflict = errors.New("cannot use both 'container' field and 'executor' field - the 'container' field already specifies the execution method")
ErrContainerAndScriptConflict = errors.New("cannot use 'script' field with 'container' field - use 'command' field instead")
ErrInvalidEnvValue = errors.New("env config should be map of strings or array of key=value formatted string")
ErrInvalidParamValue = errors.New("invalid parameter value")
ErrStepCommandIsEmpty = errors.New("step command is empty")
ErrStepCommandMustBeArrayOrString = errors.New("step command must be an array of strings or a string")
ErrTimeoutSecMustBeNonNegative = errors.New("timeoutSec must be >= 0")
ErrExecutorDoesNotSupportMultipleCmd = errors.New("executor does not support multiple commands")
ErrSubDAGAndExecutorConflict = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'executor' field")
ErrSubDAGAndCommandConflict = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'command' field")
ErrSubDAGAndScriptConflict = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'script' field")
)

View File

@ -14,6 +14,7 @@ import (
"dario.cat/mergo"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/go-viper/mapstructure/v2"
"github.com/goccy/go-yaml"
@ -212,7 +213,7 @@ func LoadYAMLWithOpts(ctx context.Context, data []byte, opts BuildOpts) (*core.D
return nil, core.ErrorList{err}
}
return build(BuildContext{ctx: ctx, opts: opts}, def)
return def.build(BuildContext{ctx: ctx, opts: opts})
}
// LoadBaseConfig loads the global configuration from the given file.
@ -229,7 +230,7 @@ func LoadBaseConfig(ctx BuildContext, file string) (*core.DAG, error) {
return nil, err
}
// Decode the raw data into a config definition.
// Decode the raw data into a manifest.
def, err := decode(raw)
if err != nil {
return nil, core.ErrorList{err}
@ -238,12 +239,8 @@ func LoadBaseConfig(ctx BuildContext, file string) (*core.DAG, error) {
ctx = ctx.WithOpts(BuildOpts{
Flags: ctx.opts.Flags,
}).WithFile(file)
dag, err := build(ctx, def)
if err != nil {
return nil, core.ErrorList{err}
}
return dag, nil
return def.build(ctx)
}
// loadDAG loads the core.DAG from the given file.
@ -255,8 +252,8 @@ func loadDAG(ctx BuildContext, nameOrPath string) (*core.DAG, error) {
ctx = ctx.WithFile(filePath)
// Load base config definition if specified
var baseDef *definition
// Load base manifest if specified
var baseDef *dag
if !ctx.opts.Has(BuildFlagOnlyMetadata) && ctx.opts.Base != "" && fileutil.FileExists(ctx.opts.Base) {
raw, err := readYAMLFile(ctx.opts.Base)
if err != nil {
@ -308,7 +305,7 @@ func loadDAG(ctx BuildContext, nameOrPath string) (*core.DAG, error) {
}
// loadDAGsFromFile loads all DAGs from a multi-document YAML file
func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([]*core.DAG, error) {
func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *dag) ([]*core.DAG, error) {
// Open the file
f, err := os.Open(filePath) //nolint:gosec
if err != nil {
@ -348,13 +345,13 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
// Update the context with the current document index
ctx.index = docIndex
// Decode the document into definition
// Decode the document into manifest
spec, err := decode(doc)
if err != nil {
return nil, fmt.Errorf("failed to decode document %d: %w", docIndex, err)
}
// Build a fresh base core.DAG from base definition if provided
// Build a fresh base core.DAG from base manifest if provided
var dest *core.DAG
if baseDef != nil {
// Build a new base core.DAG for this document
@ -363,7 +360,7 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
buildCtx.opts.Parameters = ""
buildCtx.opts.ParametersList = nil
// Build the base core.DAG
baseDAG, err := build(buildCtx, baseDef)
baseDAG, err := baseDef.build(buildCtx)
if err != nil {
return nil, fmt.Errorf("failed to build base core.DAG for document %d: %w", docIndex, err)
}
@ -380,7 +377,7 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
}
// Build the core.DAG from the current document
dag, err := build(ctx, spec)
dag, err := spec.build(ctx)
if err != nil {
return nil, fmt.Errorf("failed to build core.DAG in document %d: %w", docIndex, err)
}
@ -540,19 +537,75 @@ func unmarshalData(data []byte) (map[string]any, error) {
return cm, err
}
// decode decodes the configuration map into a configDefinition.
func decode(cm map[string]any) (*definition, error) {
c := new(definition)
// decode decodes the configuration map into a manifest.
func decode(cm map[string]any) (*dag, error) {
c := new(dag)
md, _ := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
ErrorUnused: true,
Result: c,
TagName: "",
DecodeHook: TypedUnionDecodeHook(),
})
err := md.Decode(cm)
return c, err
}
// TypedUnionDecodeHook returns a decode hook that handles our typed union types.
// It converts raw map[string]any values to the appropriate typed values.
func TypedUnionDecodeHook() mapstructure.DecodeHookFunc {
return func(_ reflect.Type, to reflect.Type, data any) (any, error) {
// Handle types.ShellValue
if to == reflect.TypeOf(types.ShellValue{}) {
return decodeViaYAML[types.ShellValue](data)
}
// Handle types.StringOrArray
if to == reflect.TypeOf(types.StringOrArray{}) {
return decodeViaYAML[types.StringOrArray](data)
}
// Handle types.ScheduleValue
if to == reflect.TypeOf(types.ScheduleValue{}) {
return decodeViaYAML[types.ScheduleValue](data)
}
// Handle types.EnvValue
if to == reflect.TypeOf(types.EnvValue{}) {
return decodeViaYAML[types.EnvValue](data)
}
// Handle types.ContinueOnValue
if to == reflect.TypeOf(types.ContinueOnValue{}) {
return decodeViaYAML[types.ContinueOnValue](data)
}
// Handle types.PortValue
if to == reflect.TypeOf(types.PortValue{}) {
return decodeViaYAML[types.PortValue](data)
}
// Handle types.LogOutputValue
if to == reflect.TypeOf(types.LogOutputValue{}) {
return decodeViaYAML[types.LogOutputValue](data)
}
return data, nil
}
}
// decodeViaYAML converts data to YAML and unmarshals it to the target type.
// This allows the custom UnmarshalYAML methods to be used.
func decodeViaYAML[T any](data any) (T, error) {
var result T
if data == nil {
return result, nil
}
// Convert the data to YAML bytes
yamlBytes, err := yaml.Marshal(data)
if err != nil {
return result, fmt.Errorf("failed to marshal to YAML: %w", err)
}
// Unmarshal using the custom UnmarshalYAML method
if err := yaml.Unmarshal(yamlBytes, &result); err != nil {
return result, fmt.Errorf("failed to unmarshal from YAML: %w", err)
}
return result, nil
}
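Because decodeViaYAML round-trips raw values through YAML, any type with a custom UnmarshalYAML can act as a union. Below is a runnable sketch using a hypothetical stand-in for the types package's union types (the real definitions are not shown in this diff); it relies on goccy/go-yaml's bytes-based unmarshaler interface:

```go
package main

import (
	"fmt"

	"github.com/goccy/go-yaml"
)

// stringOrArray is an illustrative stand-in for types.StringOrArray:
// it accepts either a single YAML scalar or a YAML sequence of scalars.
type stringOrArray struct {
	Values []string
}

// UnmarshalYAML implements goccy/go-yaml's BytesUnmarshaler,
// which is what the decodeViaYAML round trip relies on.
func (s *stringOrArray) UnmarshalYAML(b []byte) error {
	var one string
	if err := yaml.Unmarshal(b, &one); err == nil {
		s.Values = []string{one}
		return nil
	}
	return yaml.Unmarshal(b, &s.Values)
}

func main() {
	// Simulate the decode-hook path: a raw value from mapstructure is
	// marshaled back to YAML, then the custom UnmarshalYAML runs.
	for _, raw := range []any{"bash -e", []any{"bash", "-e"}} {
		b, err := yaml.Marshal(raw)
		if err != nil {
			panic(err)
		}
		var v stringOrArray
		if err := yaml.Unmarshal(b, &v); err != nil {
			panic(err)
		}
		fmt.Printf("%#v\n", v.Values) // ["bash -e"], then ["bash", "-e"]
	}
}
```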
// merge merges the source core.DAG into the destination DAG.
func merge(dst, src *core.DAG) error {
return mergo.Merge(dst, src, mergo.WithOverride,

View File

@ -15,7 +15,7 @@ import (
func TestLoad(t *testing.T) {
t.Parallel()
t.Run(("WithName"), func(t *testing.T) {
t.Run("WithName", func(t *testing.T) {
t.Parallel()
testDAG := createTempYAMLFile(t, `steps:
@ -26,43 +26,72 @@ func TestLoad(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "testDAG", dag.Name)
})
t.Run("InvalidPath", func(t *testing.T) {
t.Parallel()
// Use a non-existing file path
testDAG := "/tmp/non_existing_file_" + t.Name() + ".yaml"
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
})
t.Run("UnknownField", func(t *testing.T) {
t.Parallel()
// Error cases
errorTests := []struct {
name string
content string
useFile bool
errContains string
}{
{
name: "InvalidPath",
useFile: false,
},
{
name: "UnknownField",
content: "invalidKey: test\n",
useFile: true,
errContains: "has invalid keys: invalidKey",
},
{
name: "InvalidYAML",
content: "invalidyaml",
useFile: true,
errContains: "invalidyaml",
},
}
testDAG := createTempYAMLFile(t, `invalidKey: test
`)
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
require.Contains(t, err.Error(), "has invalid keys: invalidKey")
})
t.Run("InvalidYAML", func(t *testing.T) {
t.Parallel()
for _, tt := range errorTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var testDAG string
if tt.useFile {
testDAG = createTempYAMLFile(t, tt.content)
} else {
testDAG = "/tmp/non_existing_file_" + t.Name() + ".yaml"
}
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
if tt.errContains != "" {
require.Contains(t, err.Error(), tt.errContains)
}
})
}
testDAG := createTempYAMLFile(t, `invalidyaml`)
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
require.Contains(t, err.Error(), "invalidyaml")
})
t.Run("MetadataOnly", func(t *testing.T) {
t.Parallel()
testDAG := createTempYAMLFile(t, `steps:
testDAG := createTempYAMLFile(t, `
logDir: /var/log/dagu
histRetentionDays: 90
maxCleanUpTimeSec: 60
mailOn:
failure: true
steps:
- name: "1"
command: "true"
`)
dag, err := spec.Load(context.Background(), testDAG, spec.OnlyMetadata())
require.NoError(t, err)
// Steps should not be loaded in metadata-only mode
require.Empty(t, dag.Steps)
// Check if the metadata is loaded correctly
require.Len(t, dag.Steps, 0)
// Non-metadata fields from YAML should NOT be populated in metadata-only mode
assert.Empty(t, dag.LogDir, "LogDir should be empty in metadata-only mode")
assert.Nil(t, dag.MailOn, "MailOn should be nil in metadata-only mode")
// Defaults are still applied by InitializeDefaults (not from YAML)
assert.Equal(t, 30, dag.HistRetentionDays, "HistRetentionDays should be default (30), not YAML value (90)")
assert.Equal(t, 5*time.Second, dag.MaxCleanUpTime, "MaxCleanUpTime should be default (5s), not YAML value (60s)")
})
t.Run("DefaultConfig", func(t *testing.T) {
t.Parallel()
@ -84,7 +113,8 @@ func TestLoad(t *testing.T) {
// Step level
require.Len(t, dag.Steps, 1)
assert.Equal(t, "1", dag.Steps[0].Name, "1")
assert.Equal(t, "true", dag.Steps[0].Command, "true")
require.Len(t, dag.Steps[0].Commands, 1)
assert.Equal(t, "true", dag.Steps[0].Commands[0].Command)
})
t.Run("OverrideConfig", func(t *testing.T) {
t.Parallel()
@ -190,26 +220,45 @@ steps:
func TestLoadYAML(t *testing.T) {
t.Parallel()
const testDAG = `steps:
tests := []struct {
name string
input string
wantErr bool
wantName string
wantCommand string
}{
{
name: "ValidYAMLData",
input: `steps:
- name: "1"
command: "true"
`
t.Run("ValidYAMLData", func(t *testing.T) {
t.Parallel()
`,
wantName: "1",
wantCommand: "true",
},
{
name: "InvalidYAMLData",
input: "invalid",
wantErr: true,
},
}
ret, err := spec.LoadYAMLWithOpts(context.Background(), []byte(testDAG), spec.BuildOpts{})
require.NoError(t, err)
step := ret.Steps[0]
require.Equal(t, "1", step.Name)
require.Equal(t, "true", step.Command)
})
t.Run("InvalidYAMLData", func(t *testing.T) {
t.Parallel()
_, err := spec.LoadYAMLWithOpts(context.Background(), []byte(`invalid`), spec.BuildOpts{})
require.Error(t, err)
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ret, err := spec.LoadYAMLWithOpts(context.Background(), []byte(tt.input), spec.BuildOpts{})
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Len(t, ret.Steps, 1)
assert.Equal(t, tt.wantName, ret.Steps[0].Name)
require.Len(t, ret.Steps[0].Commands, 1)
assert.Equal(t, tt.wantCommand, ret.Steps[0].Commands[0].Command)
})
}
}
func TestLoadYAMLWithNameOption(t *testing.T) {
@ -228,7 +277,8 @@ steps:
step := ret.Steps[0]
require.Equal(t, "1", step.Name)
require.Equal(t, "true", step.Command)
require.Len(t, step.Commands, 1)
require.Equal(t, "true", step.Commands[0].Command)
}
// createTempYAMLFile creates a temporary YAML file with the given content
@ -295,7 +345,8 @@ steps:
assert.Equal(t, "transform-data", transformDAG.Name)
assert.Len(t, transformDAG.Steps, 1)
assert.Equal(t, "transform", transformDAG.Steps[0].Name)
assert.Equal(t, "transform.py", transformDAG.Steps[0].Command)
require.Len(t, transformDAG.Steps[0].Commands, 1)
assert.Equal(t, "transform.py", transformDAG.Steps[0].Commands[0].Command)
// Check archive-results sub DAG
_, exists = dag.LocalDAGs["archive-results"]
@ -304,7 +355,8 @@ steps:
assert.Equal(t, "archive-results", archiveDAG.Name)
assert.Len(t, archiveDAG.Steps, 1)
assert.Equal(t, "archive", archiveDAG.Steps[0].Name)
assert.Equal(t, "archive.sh", archiveDAG.Steps[0].Command)
require.Len(t, archiveDAG.Steps[0].Commands, 1)
assert.Equal(t, "archive.sh", archiveDAG.Steps[0].Commands[0].Command)
})
t.Run("MultiDAGWithBaseConfig", func(t *testing.T) {
@ -385,11 +437,15 @@ steps:
assert.Nil(t, dag.LocalDAGs) // No sub DAGs for single DAG file
})
t.Run("DuplicateSubDAGNames", func(t *testing.T) {
t.Parallel()
// Multi-DAG file with duplicate names
multiDAGContent := `steps:
// MultiDAG error cases
multiDAGErrorTests := []struct {
name string
content string
errContains string
}{
{
name: "DuplicateSubDAGNames",
content: `steps:
- name: step1
command: echo "main"
@ -404,20 +460,12 @@ name: duplicate-name
steps:
- name: step1
command: echo "second"
`
tmpFile := createTempYAMLFile(t, multiDAGContent)
// Load should fail due to duplicate names
_, err := spec.Load(context.Background(), tmpFile)
require.Error(t, err)
assert.Contains(t, err.Error(), "duplicate DAG name")
})
t.Run("SubDAGWithoutName", func(t *testing.T) {
t.Parallel()
// Multi-DAG file where sub DAG has no name
multiDAGContent := `steps:
`,
errContains: "duplicate DAG name",
},
{
name: "SubDAGWithoutName",
content: `steps:
- name: step1
command: echo "main"
@ -425,14 +473,20 @@ steps:
steps:
- name: step1
command: echo "unnamed"
`
tmpFile := createTempYAMLFile(t, multiDAGContent)
`,
errContains: "must have a name",
},
}
// Load should fail because sub DAG has no name
_, err := spec.Load(context.Background(), tmpFile)
require.Error(t, err)
assert.Contains(t, err.Error(), "must have a name")
})
for _, tt := range multiDAGErrorTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tmpFile := createTempYAMLFile(t, tt.content)
_, err := spec.Load(context.Background(), tmpFile)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
})
}
t.Run("EmptyDocumentSeparator", func(t *testing.T) {
t.Parallel()
@ -615,3 +669,147 @@ steps:
assert.Equal(t, explicitDir, dag.WorkingDir)
})
}
func TestLoadWithLoaderOptions(t *testing.T) {
t.Parallel()
t.Run("WithDAGsDir", func(t *testing.T) {
t.Parallel()
// Create a DAGs directory
dagsDir := t.TempDir()
// Create a sub-DAG file
subDAGPath := filepath.Join(dagsDir, "subdag.yaml")
require.NoError(t, os.WriteFile(subDAGPath, []byte(`
steps:
- name: sub-step
command: echo sub
`), 0644))
// Create main DAG that calls the sub-DAG
mainDAG := createTempYAMLFile(t, `
steps:
- name: main-step
command: echo main
`)
// Load with WithDAGsDir
dag, err := spec.Load(context.Background(), mainDAG, spec.WithDAGsDir(dagsDir))
require.NoError(t, err)
require.NotNil(t, dag)
})
t.Run("WithAllowBuildErrors", func(t *testing.T) {
t.Parallel()
testDAG := createTempYAMLFile(t, `
steps:
- name: test
command: echo test
depends:
- nonexistent-step
`)
// Without AllowBuildErrors, this would fail
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
// With AllowBuildErrors, it should succeed but capture errors
dag, err := spec.Load(context.Background(), testDAG, spec.WithAllowBuildErrors())
require.NoError(t, err)
require.NotNil(t, dag)
assert.NotEmpty(t, dag.BuildErrors)
})
t.Run("SkipSchemaValidation", func(t *testing.T) {
t.Parallel()
testDAG := createTempYAMLFile(t, `
params:
schema: "nonexistent-schema.json"
values:
foo: bar
steps:
- name: test
command: echo test
`)
// Without SkipSchemaValidation, this would fail due to missing schema
_, err := spec.Load(context.Background(), testDAG)
require.Error(t, err)
// With SkipSchemaValidation, it should succeed
dag, err := spec.Load(context.Background(), testDAG, spec.SkipSchemaValidation())
require.NoError(t, err)
require.NotNil(t, dag)
})
t.Run("WithSkipBaseHandlers", func(t *testing.T) {
t.Parallel()
// Create base config with handlers
baseDir := t.TempDir()
baseConfig := filepath.Join(baseDir, "base.yaml")
require.NoError(t, os.WriteFile(baseConfig, []byte(`
handlerOn:
success:
command: echo base-success
`), 0644))
testDAG := createTempYAMLFile(t, `
steps:
- name: test
command: echo test
`)
// Load with base config but skip base handlers
dag, err := spec.Load(context.Background(), testDAG,
spec.WithBaseConfig(baseConfig),
spec.WithSkipBaseHandlers())
require.NoError(t, err)
require.NotNil(t, dag)
// The base success handler should not be present
assert.Nil(t, dag.HandlerOn.Success)
})
t.Run("WithParamsAsList", func(t *testing.T) {
t.Parallel()
testDAG := createTempYAMLFile(t, `
params: KEY1 KEY2
steps:
- name: test
command: echo $KEY1 $KEY2
`)
// Load with params as list
dag, err := spec.Load(context.Background(), testDAG,
spec.WithParams([]string{"KEY1=value1", "KEY2=value2"}))
require.NoError(t, err)
require.NotNil(t, dag)
// Check that params were applied
found := 0
for _, p := range dag.Params {
if p == "KEY1=value1" || p == "KEY2=value2" {
found++
}
}
assert.Equal(t, 2, found, "Both params should be applied")
})
}
// TestLoadWithoutEval tests the WithoutEval loader option
// This test cannot be parallel because it uses t.Setenv
func TestLoadWithoutEval(t *testing.T) {
t.Setenv("TEST_VAR", "should-not-expand")
testDAG := createTempYAMLFile(t, `
env:
- MY_VAR: "${TEST_VAR}"
steps:
- name: test
command: echo test
`)
dag, err := spec.Load(context.Background(), testDAG, spec.WithoutEval())
require.NoError(t, err)
// With NoEval, environment variables should not be expanded
assert.Contains(t, dag.Env, "MY_VAR=${TEST_VAR}")
}

View File

@ -3,88 +3,17 @@ package spec
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"github.com/dagu-org/dagu/internal/common/cmdutil"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/dagu-org/dagu/internal/core"
"github.com/google/jsonschema-go/jsonschema"
)
// buildParams builds the parameters for the DAG.
func buildParams(ctx BuildContext, spec *definition, dag *core.DAG) error {
var (
paramPairs []paramPair
envs []string
)
if err := parseParams(ctx, spec.Params, &paramPairs, &envs); err != nil {
return err
}
// Create default parameters string in the form of "key=value key=value ..."
var paramsToJoin []string
for _, paramPair := range paramPairs {
paramsToJoin = append(paramsToJoin, paramPair.Escaped())
}
dag.DefaultParams = strings.Join(paramsToJoin, " ")
if ctx.opts.Parameters != "" {
// Parse the parameters from the command line and override the default parameters
var (
overridePairs []paramPair
overrideEnvs []string
)
if err := parseParams(ctx, ctx.opts.Parameters, &overridePairs, &overrideEnvs); err != nil {
return err
}
// Override the default parameters with the command line parameters
overrideParams(&paramPairs, overridePairs)
}
if len(ctx.opts.ParametersList) > 0 {
var (
overridePairs []paramPair
overrideEnvs []string
)
if err := parseParams(ctx, ctx.opts.ParametersList, &overridePairs, &overrideEnvs); err != nil {
return err
}
// Override the default parameters with the command line parameters
overrideParams(&paramPairs, overridePairs)
}
// Validate the parameters against a resolved schema (if declared)
if !ctx.opts.Has(BuildFlagSkipSchemaValidation) {
if resolvedSchema, err := resolveSchemaFromParams(spec.Params, spec.WorkingDir, dag.Location); err != nil {
return fmt.Errorf("failed to get JSON schema: %w", err)
} else if resolvedSchema != nil {
updatedPairs, err := validateParams(paramPairs, resolvedSchema)
if err != nil {
return err
}
paramPairs = updatedPairs
}
}
for _, paramPair := range paramPairs {
dag.Params = append(dag.Params, paramPair.String())
}
dag.Env = append(dag.Env, envs...)
return nil
}
func validateParams(paramPairs []paramPair, schema *jsonschema.Resolved) ([]paramPair, error) {
// Convert paramPairs to a map for validation
paramMap := make(map[string]any)
@ -127,134 +56,6 @@ func validateParams(paramPairs []paramPair, schema *jsonschema.Resolved) ([]para
return updatedPairs, nil
}
// Schema Ref can be a local file (relative or absolute paths), or a remote URL
func getSchemaFromRef(workingDir string, dagLocation string, schemaRef string) (*jsonschema.Resolved, error) {
var schemaData []byte
var err error
// Check if it's a URL or file path
if strings.HasPrefix(schemaRef, "http://") || strings.HasPrefix(schemaRef, "https://") {
schemaData, err = loadSchemaFromURL(schemaRef)
} else {
schemaData, err = loadSchemaFromFile(workingDir, dagLocation, schemaRef)
}
if err != nil {
return nil, fmt.Errorf("failed to load schema from %s: %w", schemaRef, err)
}
var schema jsonschema.Schema
if err := json.Unmarshal(schemaData, &schema); err != nil {
return nil, fmt.Errorf("failed to parse schema JSON: %w", err)
}
resolveOptions := &jsonschema.ResolveOptions{
ValidateDefaults: true,
}
resolvedSchema, err := schema.Resolve(resolveOptions)
if err != nil {
return nil, fmt.Errorf("failed to resolve schema: %w", err)
}
return resolvedSchema, nil
}
// loadSchemaFromURL loads a JSON schema from a URL.
func loadSchemaFromURL(schemaURL string) (data []byte, err error) {
// Validate URL to prevent potential security issues (and satisfy linter :P)
parsedURL, err := url.Parse(schemaURL)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return nil, fmt.Errorf("unsupported URL scheme: %s", parsedURL.Scheme)
}
req, err := http.NewRequest("GET", schemaURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
}
data, err = io.ReadAll(resp.Body)
return data, err
}
// loadSchemaFromFile loads a JSON schema from a file path.
func loadSchemaFromFile(workingDir string, dagLocation string, filePath string) ([]byte, error) {
// Try to resolve the schema file path in the following order:
// 1) Current working directory (default ResolvePath behavior)
// 2) DAG's workingDir value
// 3) Directory of the DAG file (where it was loaded from)
var tried []string
// Attempts a candidate by joining base and filePath (if base provided),
// resolving env/tilde + absolute path, checking existence, and reading.
tryCandidate := func(label, base string) ([]byte, string, error) {
var candidate string
if strings.TrimSpace(base) == "" {
candidate = filePath
} else {
candidate = filepath.Join(base, filePath)
}
resolved, err := fileutil.ResolvePath(candidate)
if err != nil {
tried = append(tried, fmt.Sprintf("%s: resolve error: %v", label, err))
return nil, "", err
}
if !fileutil.FileExists(resolved) {
tried = append(tried, fmt.Sprintf("%s: %s", label, resolved))
return nil, resolved, os.ErrNotExist
}
data, err := os.ReadFile(resolved) // #nosec G304 - validated path
if err != nil {
tried = append(tried, fmt.Sprintf("%s: %s (read error: %v)", label, resolved, err))
return nil, resolved, err
}
return data, resolved, nil
}
// 1) As provided (CWD/env/tilde expansion handled by ResolvePath)
if data, _, err := tryCandidate("cwd", ""); err == nil {
return data, nil
}
// 2) From DAG's workingDir value if present
if wd := strings.TrimSpace(workingDir); wd != "" {
if data, _, err := tryCandidate(fmt.Sprintf("workingDir(%s)", wd), wd); err == nil {
return data, nil
}
}
// 3) From the directory of the DAG file used to build
if dagLocation != "" {
base := filepath.Dir(dagLocation)
if data, _, err := tryCandidate(fmt.Sprintf("dagDir(%s)", base), base); err == nil {
return data, nil
}
}
if len(tried) == 0 {
return nil, fmt.Errorf("failed to resolve schema file path: %s (no candidates)", filePath)
}
return nil, fmt.Errorf("schema file not found for %q; tried %s", filePath, strings.Join(tried, ", "))
}
func overrideParams(paramPairs *[]paramPair, override []paramPair) {
// Override the default parameters with the command line parameters (and satisfy linter :P)
pairsIndex := make(map[string]int)
@ -514,57 +315,3 @@ func (p paramPair) Escaped() string {
}
return fmt.Sprintf("%q", p.Value)
}
// extractSchemaReference extracts the schema reference from a params map.
// Returns the schema reference as a string if present and valid, empty string otherwise.
func extractSchemaReference(params any) string {
paramsMap, ok := params.(map[string]any)
if !ok {
return ""
}
schemaRef, hasSchema := paramsMap["schema"]
if !hasSchema {
return ""
}
schemaRefStr, ok := schemaRef.(string)
if !ok {
return ""
}
return schemaRefStr
}
// resolveSchemaFromParams extracts a schema reference from params and resolves it.
// Returns (nil, nil) if no schema is declared.
func resolveSchemaFromParams(params any, workingDir, dagLocation string) (*jsonschema.Resolved, error) {
schemaRef := extractSchemaReference(params)
if schemaRef == "" {
return nil, nil
}
return getSchemaFromRef(workingDir, dagLocation, schemaRef)
}
// buildStepParams parses the params field in the step definition.
// Params are converted to map[string]string and stored in step.Params
func buildStepParams(ctx StepBuildContext, def stepDef, step *core.Step) error {
if def.Params == nil {
return nil
}
// Parse params using existing parseParamValue function
paramPairs, err := parseParamValue(ctx.BuildContext, def.Params)
if err != nil {
return core.NewValidationError("params", def.Params, err)
}
// Convert to map[string]string
paramsData := make(map[string]string)
for _, pair := range paramPairs {
paramsData[pair.Name] = pair.Value
}
step.Params = core.NewSimpleParams(paramsData)
return nil
}

View File

@ -0,0 +1,662 @@
package spec
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParamPairString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pair paramPair
expected string
}{
{
name: "NamedParam",
pair: paramPair{Name: "foo", Value: "bar"},
expected: "foo=bar",
},
{
name: "PositionalParam",
pair: paramPair{Name: "", Value: "value"},
expected: "value",
},
{
name: "EmptyValue",
pair: paramPair{Name: "key", Value: ""},
expected: "key=",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.expected, tt.pair.String())
})
}
}
func TestParamPairEscaped(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pair paramPair
expected string
}{
{
name: "NamedParam",
pair: paramPair{Name: "foo", Value: "bar"},
expected: `foo="bar"`,
},
{
name: "PositionalParam",
pair: paramPair{Name: "", Value: "value"},
expected: `"value"`,
},
{
name: "ValueWithSpaces",
pair: paramPair{Name: "msg", Value: "hello world"},
expected: `msg="hello world"`,
},
{
name: "ValueWithQuotes",
pair: paramPair{Name: "json", Value: `{"key":"value"}`},
expected: `json="{\"key\":\"value\"}"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.expected, tt.pair.Escaped())
})
}
}
func TestParseStringParams(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
tests := []struct {
name string
input string
expected []paramPair
}{
{
name: "SinglePositionalParam",
input: "value",
expected: []paramPair{{Name: "", Value: "value"}},
},
{
name: "SingleNamedParam",
input: "key=value",
expected: []paramPair{{Name: "key", Value: "value"}},
},
{
name: "MultipleNamedParams",
input: "foo=bar baz=qux",
expected: []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
},
},
{
name: "MixedParams",
input: "positional key=value",
expected: []paramPair{
{Name: "", Value: "positional"},
{Name: "key", Value: "value"},
},
},
{
name: "QuotedValue",
input: `msg="hello world"`,
expected: []paramPair{{Name: "msg", Value: "hello world"}},
},
{
name: "QuotedValueWithEscapedQuotes",
input: `msg="say \"hello\""`,
expected: []paramPair{{Name: "msg", Value: `say "hello\`}},
},
{
name: "EmptyString",
input: "",
expected: nil,
},
{
name: "MultiplePositionalParams",
input: "one two three",
expected: []paramPair{
{Name: "", Value: "one"},
{Name: "", Value: "two"},
{Name: "", Value: "three"},
},
},
{
name: "BacktickValue",
input: "cmd=`echo hello`",
expected: []paramPair{{Name: "cmd", Value: "`echo hello`"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := parseStringParams(ctx, tt.input)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseStringParamsWithEval(t *testing.T) {
t.Run("BacktickCommandSubstitution", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
result, err := parseStringParams(ctx, "val=`echo hello`")
require.NoError(t, err)
require.Len(t, result, 1)
assert.Equal(t, "val", result[0].Name)
assert.Equal(t, "hello", result[0].Value)
})
}
func TestParseListParams(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
tests := []struct {
name string
input []string
expected []paramPair
}{
{
name: "EmptyList",
input: []string{},
expected: nil,
},
{
name: "SingleItem",
input: []string{"foo=bar"},
expected: []paramPair{{Name: "foo", Value: "bar"}},
},
{
name: "MultipleItems",
input: []string{"foo=bar", "baz=qux"},
expected: []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
},
},
{
name: "ItemsWithMultipleParams",
input: []string{"a=1 b=2", "c=3"},
expected: []paramPair{
{Name: "a", Value: "1"},
{Name: "b", Value: "2"},
{Name: "c", Value: "3"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := parseListParams(ctx, tt.input)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseMapParams(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
t.Run("EmptySlice", func(t *testing.T) {
t.Parallel()
result, err := parseMapParams(ctx, []any{})
require.NoError(t, err)
assert.Nil(t, result)
})
t.Run("SingleMap", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"foo": "bar"},
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
})
t.Run("MapWithMultipleKeys", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"a": "1", "b": "2"},
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
// Keys are sorted alphabetically
assert.Equal(t, []paramPair{
{Name: "a", Value: "1"},
{Name: "b", Value: "2"},
}, result)
})
t.Run("MultipleMaps", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"foo": "bar"},
map[string]any{"baz": "qux"},
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, result)
})
t.Run("MixedMapAndString", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"foo": "bar"},
"baz=qux",
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, result)
})
t.Run("IntegerValue", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"count": 42},
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
assert.Equal(t, []paramPair{{Name: "count", Value: "42"}}, result)
})
t.Run("BooleanValue", func(t *testing.T) {
t.Parallel()
input := []any{
map[string]any{"debug": true},
}
result, err := parseMapParams(ctx, input)
require.NoError(t, err)
assert.Equal(t, []paramPair{{Name: "debug", Value: "true"}}, result)
})
t.Run("InvalidType", func(t *testing.T) {
t.Parallel()
input := []any{123}
_, err := parseMapParams(ctx, input)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid parameter value")
})
}
func TestParseParamValue(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
t.Run("Nil", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, nil)
require.NoError(t, err)
assert.Nil(t, result)
})
t.Run("String", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, "foo=bar baz=qux")
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, result)
})
t.Run("StringSlice", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, []string{"foo=bar", "baz=qux"})
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, result)
})
t.Run("AnySlice", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, []any{
map[string]any{"foo": "bar"},
})
require.NoError(t, err)
assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
})
t.Run("MapWithoutSchema", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, map[string]any{
"foo": "bar",
"baz": "qux",
})
require.NoError(t, err)
// Keys are sorted
assert.Len(t, result, 2)
})
t.Run("MapWithSchemaNoValues", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, map[string]any{
"schema": "schema.json",
})
require.NoError(t, err)
assert.Empty(t, result)
})
t.Run("MapWithSchemaAndValues", func(t *testing.T) {
t.Parallel()
result, err := parseParamValue(ctx, map[string]any{
"schema": "schema.json",
"values": map[string]any{"foo": "bar"},
})
require.NoError(t, err)
assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
})
t.Run("InvalidType", func(t *testing.T) {
t.Parallel()
_, err := parseParamValue(ctx, 123)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid parameter value")
})
}
func TestOverrideParams(t *testing.T) {
t.Parallel()
t.Run("OverrideByName", func(t *testing.T) {
t.Parallel()
params := []paramPair{
{Name: "foo", Value: "original"},
{Name: "bar", Value: "keep"},
}
override := []paramPair{
{Name: "foo", Value: "overridden"},
}
overrideParams(&params, override)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "overridden"},
{Name: "bar", Value: "keep"},
}, params)
})
t.Run("AddNewNamedParam", func(t *testing.T) {
t.Parallel()
params := []paramPair{
{Name: "foo", Value: "bar"},
}
override := []paramPair{
{Name: "baz", Value: "qux"},
}
overrideParams(&params, override)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, params)
})
t.Run("OverrideByPosition", func(t *testing.T) {
t.Parallel()
params := []paramPair{
{Name: "", Value: "first"},
{Name: "", Value: "second"},
}
override := []paramPair{
{Name: "", Value: "new-first"},
}
overrideParams(&params, override)
assert.Equal(t, []paramPair{
{Name: "", Value: "new-first"},
{Name: "", Value: "second"},
}, params)
})
t.Run("AddPositionalBeyondLength", func(t *testing.T) {
t.Parallel()
params := []paramPair{
{Name: "", Value: "first"},
}
override := []paramPair{
{Name: "", Value: "new-first"},
{Name: "", Value: "new-second"},
}
overrideParams(&params, override)
assert.Equal(t, []paramPair{
{Name: "", Value: "new-first"},
{Name: "", Value: "new-second"},
}, params)
})
t.Run("EmptyOverride", func(t *testing.T) {
t.Parallel()
params := []paramPair{
{Name: "foo", Value: "bar"},
}
overrideParams(&params, []paramPair{})
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
}, params)
})
t.Run("EmptyParams", func(t *testing.T) {
t.Parallel()
params := []paramPair{}
override := []paramPair{
{Name: "foo", Value: "bar"},
}
overrideParams(&params, override)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
}, params)
})
}
func TestParseParams(t *testing.T) {
t.Parallel()
t.Run("SimpleParams", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
var params []paramPair
var envs []string
err := parseParams(ctx, "foo=bar baz=qux", &params, &envs)
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, params)
// NoEval flag prevents env vars from being added
assert.Empty(t, envs)
})
t.Run("PositionalParamsGetNames", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
var params []paramPair
var envs []string
err := parseParams(ctx, "one two three", &params, &envs)
require.NoError(t, err)
// Positional params get numbered names
assert.Equal(t, []paramPair{
{Name: "1", Value: "one"},
{Name: "2", Value: "two"},
{Name: "3", Value: "three"},
}, params)
})
t.Run("WithEvalAddsEnvs", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
var params []paramPair
var envs []string
err := parseParams(ctx, "foo=bar baz=qux", &params, &envs)
require.NoError(t, err)
assert.Equal(t, []paramPair{
{Name: "foo", Value: "bar"},
{Name: "baz", Value: "qux"},
}, params)
assert.Equal(t, []string{"foo=bar", "baz=qux"}, envs)
})
t.Run("VariableReference", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
var params []paramPair
var envs []string
// Second param references first param
err := parseParams(ctx, []any{
map[string]any{"BASE": "/opt"},
map[string]any{"PATH_VAR": "${BASE}/bin"},
}, &params, &envs)
require.NoError(t, err)
assert.Equal(t, "/opt", params[0].Value)
assert.Equal(t, "/opt/bin", params[1].Value)
})
t.Run("NilInput", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
var params []paramPair
var envs []string
err := parseParams(ctx, nil, &params, &envs)
require.NoError(t, err)
assert.Empty(t, params)
assert.Empty(t, envs)
})
t.Run("InvalidInput", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
var params []paramPair
var envs []string
err := parseParams(ctx, 123, &params, &envs)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid parameter value")
})
}
func TestEvalParamValue(t *testing.T) {
t.Run("SimpleValue", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
}
result, err := evalParamValue(ctx, "hello", nil)
require.NoError(t, err)
assert.Equal(t, "hello", result)
})
t.Run("WithAccumulatedVars", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
}
vars := map[string]string{"BASE": "/opt"}
result, err := evalParamValue(ctx, "${BASE}/bin", vars)
require.NoError(t, err)
assert.Equal(t, "/opt/bin", result)
})
t.Run("WithBuildEnv", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
buildEnv: map[string]string{"ENV_VAR": "value"},
}
result, err := evalParamValue(ctx, "${ENV_VAR}", nil)
require.NoError(t, err)
assert.Equal(t, "value", result)
})
t.Run("AccumulatedVarsPrecedence", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
buildEnv: map[string]string{"VAR": "from-env"},
}
vars := map[string]string{"VAR": "from-accumulated"}
result, err := evalParamValue(ctx, "${VAR}", vars)
require.NoError(t, err)
// Accumulated vars should take precedence (added first to options)
assert.Equal(t, "from-accumulated", result)
})
}
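Stepping back from the file above: TestParseParamValue exercises every input shape parseParamValue accepts. A condensed sketch, using the same ctx the tests build (results elided):

_, _ = parseParamValue(ctx, "foo=bar baz=qux")                   // whitespace-separated string
_, _ = parseParamValue(ctx, []string{"foo=bar", "baz=qux"})      // slice of KEY=value strings
_, _ = parseParamValue(ctx, []any{map[string]any{"foo": "bar"}}) // slice of maps and/or strings
_, _ = parseParamValue(ctx, map[string]any{ // map with optional schema/values keys
	"schema": "schema.json",
	"values": map[string]any{"foo": "bar"},
})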


@@ -26,78 +26,3 @@ func buildScheduler(values []string) ([]core.Schedule, error) {
return ret, nil
}
// parseScheduleMap parses the schedule map and populates the starts, stops,
// and restarts slices. Each key in the map must be one of "start", "stop", or
// "restart". The value can take one of two forms:
//
// Case 1: The value is a string.
// Case 2: The value is an array of strings.
//
// Example:
// ```yaml
// schedule:
//   start: "0 1 * * *"
//   stop: "0 18 * * *"
//   restart:
//     - "0 1 * * *"
//     - "0 18 * * *"
// ```
func parseScheduleMap(
scheduleMap map[string]any, starts, stops, restarts *[]string,
) error {
for key, v := range scheduleMap {
var values []string
switch v := v.(type) {
case string:
// Case 1. schedule is a string.
values = append(values, v)
case []any:
// Case 2. schedule is an array of strings.
// Append all the schedules to the values slice.
for _, s := range v {
s, ok := s.(string)
if !ok {
return core.NewValidationError("schedule", s, ErrScheduleMustBeStringOrArray)
}
values = append(values, s)
}
}
var targets *[]string
switch scheduleKey(key) {
case scheduleKeyStart:
targets = starts
case scheduleKeyStop:
targets = stops
case scheduleKeyRestart:
targets = restarts
}
for _, v := range values {
if _, err := cronParser.Parse(v); err != nil {
return core.NewValidationError("schedule", v, fmt.Errorf("%w: %s", ErrInvalidSchedule, err))
}
*targets = append(*targets, v)
}
}
return nil
}
type scheduleKey string
const (
scheduleKeyStart scheduleKey = "start"
scheduleKeyStop scheduleKey = "stop"
scheduleKeyRestart scheduleKey = "restart"
)
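For orientation, a minimal usage sketch of parseScheduleMap as declared above; the map literal stands in for a decoded YAML schedule block, and the variable names are illustrative:

var starts, stops, restarts []string
scheduleMap := map[string]any{
	"start": "0 1 * * *",         // Case 1: a single cron string
	"stop":  []any{"0 18 * * *"}, // Case 2: an array of cron strings
}
if err := parseScheduleMap(scheduleMap, &starts, &stops, &restarts); err != nil {
	// a non-string entry or an invalid cron expression was rejected
}
// starts == []string{"0 1 * * *"}, stops == []string{"0 18 * * *"}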


@@ -0,0 +1,174 @@
package spec
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"github.com/dagu-org/dagu/internal/common/fileutil"
"github.com/google/jsonschema-go/jsonschema"
)
// resolveSchemaFromParams extracts a schema reference from params and resolves it.
// Returns (nil, nil) if no schema is declared.
func resolveSchemaFromParams(params any, workingDir, dagLocation string) (*jsonschema.Resolved, error) {
schemaRef := extractSchemaReference(params)
if schemaRef == "" {
return nil, nil
}
return getSchemaFromRef(workingDir, dagLocation, schemaRef)
}
// getSchemaFromRef resolves a schema reference, which can be a local file path (relative or absolute) or a remote URL.
func getSchemaFromRef(workingDir string, dagLocation string, schemaRef string) (*jsonschema.Resolved, error) {
var schemaData []byte
var err error
// Check if it's a URL or file path
if strings.HasPrefix(schemaRef, "http://") || strings.HasPrefix(schemaRef, "https://") {
schemaData, err = loadSchemaFromURL(schemaRef)
} else {
schemaData, err = loadSchemaFromFile(workingDir, dagLocation, schemaRef)
}
if err != nil {
return nil, fmt.Errorf("failed to load schema from %s: %w", schemaRef, err)
}
var schema jsonschema.Schema
if err := json.Unmarshal(schemaData, &schema); err != nil {
return nil, fmt.Errorf("failed to parse schema JSON: %w", err)
}
resolveOptions := &jsonschema.ResolveOptions{
ValidateDefaults: true,
}
resolvedSchema, err := schema.Resolve(resolveOptions)
if err != nil {
return nil, fmt.Errorf("failed to resolve schema: %w", err)
}
return resolvedSchema, nil
}
// loadSchemaFromURL loads a JSON schema from a URL.
func loadSchemaFromURL(schemaURL string) (data []byte, err error) {
// Validate URL to prevent potential security issues (and satisfy linter :P)
parsedURL, err := url.Parse(schemaURL)
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return nil, fmt.Errorf("unsupported URL scheme: %s", parsedURL.Scheme)
}
req, err := http.NewRequest("GET", schemaURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil && err == nil {
err = closeErr
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
}
data, err = io.ReadAll(resp.Body)
return data, err
}
// loadSchemaFromFile loads a JSON schema from a file path.
func loadSchemaFromFile(workingDir string, dagLocation string, filePath string) ([]byte, error) {
// Try to resolve the schema file path in the following order:
// 1) Current working directory (default ResolvePath behavior)
// 2) DAG's workingDir value
// 3) Directory of the DAG file (where it was loaded from)
var tried []string
// tryCandidate joins base and filePath (when base is provided), resolves
// env/tilde expansion to an absolute path, checks existence, and reads the file.
tryCandidate := func(label, base string) ([]byte, string, error) {
var candidate string
if strings.TrimSpace(base) == "" {
candidate = filePath
} else {
candidate = filepath.Join(base, filePath)
}
resolved, err := fileutil.ResolvePath(candidate)
if err != nil {
tried = append(tried, fmt.Sprintf("%s: resolve error: %v", label, err))
return nil, "", err
}
if !fileutil.FileExists(resolved) {
tried = append(tried, fmt.Sprintf("%s: %s", label, resolved))
return nil, resolved, os.ErrNotExist
}
data, err := os.ReadFile(resolved) // #nosec G304 - validated path
if err != nil {
tried = append(tried, fmt.Sprintf("%s: %s (read error: %v)", label, resolved, err))
return nil, resolved, err
}
return data, resolved, nil
}
// 1) As provided (CWD/env/tilde expansion handled by ResolvePath)
if data, _, err := tryCandidate("cwd", ""); err == nil {
return data, nil
}
// 2) From DAG's workingDir value if present
if wd := strings.TrimSpace(workingDir); wd != "" {
if data, _, err := tryCandidate(fmt.Sprintf("workingDir(%s)", wd), wd); err == nil {
return data, nil
}
}
// 3) From the directory of the DAG file used to build
if dagLocation != "" {
base := filepath.Dir(dagLocation)
if data, _, err := tryCandidate(fmt.Sprintf("dagDir(%s)", base), base); err == nil {
return data, nil
}
}
if len(tried) == 0 {
return nil, fmt.Errorf("failed to resolve schema file path: %s (no candidates)", filePath)
}
return nil, fmt.Errorf("schema file not found for %q; tried %s", filePath, strings.Join(tried, ", "))
}
// extractSchemaReference extracts the schema reference from a params map.
// Returns the schema reference as a string if present and valid, or an empty string otherwise.
func extractSchemaReference(params any) string {
paramsMap, ok := params.(map[string]any)
if !ok {
return ""
}
schemaRef, hasSchema := paramsMap["schema"]
if !hasSchema {
return ""
}
schemaRefStr, ok := schemaRef.(string)
if !ok {
return ""
}
return schemaRefStr
}
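To make the three-step lookup order concrete, here is a hypothetical call; all paths below are illustrative and not taken from the repository:

// For a relative ref, candidates are tried in this order:
//   1) "schemas/params.json" resolved against the current working directory
//   2) "/srv/dagu/workdir/schemas/params.json" (the DAG's workingDir)
//   3) "/srv/dagu/dags/schemas/params.json" (the directory of the DAG file)
data, err := loadSchemaFromFile("/srv/dagu/workdir", "/srv/dagu/dags/etl.yaml", "schemas/params.json")
if err != nil {
	// the returned error enumerates every candidate that was tried
}
_ = data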


@@ -0,0 +1,833 @@
package spec
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSchemaReference(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params any
expected string
}{
{
name: "Nil",
params: nil,
expected: "",
},
{
name: "NotAMap",
params: "string value",
expected: "",
},
{
name: "EmptyMap",
params: map[string]any{},
expected: "",
},
{
name: "NoSchemaKey",
params: map[string]any{"values": map[string]any{"foo": "bar"}},
expected: "",
},
{
name: "SchemaKeyNotString",
params: map[string]any{"schema": 123},
expected: "",
},
{
name: "SchemaKeyIsMap",
params: map[string]any{"schema": map[string]any{"type": "object"}},
expected: "",
},
{
name: "ValidSchemaReference",
params: map[string]any{"schema": "schema.json"},
expected: "schema.json",
},
{
name: "ValidSchemaReferenceWithValues",
params: map[string]any{"schema": "./schemas/params.json", "values": map[string]any{"foo": "bar"}},
expected: "./schemas/params.json",
},
{
name: "HTTPSchemaReference",
params: map[string]any{"schema": "https://example.com/schema.json"},
expected: "https://example.com/schema.json",
},
{
name: "EmptySchemaString",
params: map[string]any{"schema": ""},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := extractSchemaReference(tt.params)
assert.Equal(t, tt.expected, result)
})
}
}
func TestLoadSchemaFromURL(t *testing.T) {
t.Parallel()
t.Run("SuccessfulLoad", func(t *testing.T) {
t.Parallel()
schemaContent := `{"type": "object", "properties": {"foo": {"type": "string"}}}`
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(schemaContent))
}))
defer server.Close()
data, err := loadSchemaFromURL(server.URL + "/schema.json")
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
t.Run("HTTPError", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
_, err := loadSchemaFromURL(server.URL + "/missing.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "404")
})
t.Run("ServerError", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
_, err := loadSchemaFromURL(server.URL + "/schema.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "500")
})
t.Run("InvalidURL", func(t *testing.T) {
t.Parallel()
_, err := loadSchemaFromURL("://invalid-url")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid")
})
t.Run("UnsupportedScheme", func(t *testing.T) {
t.Parallel()
_, err := loadSchemaFromURL("ftp://example.com/schema.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported URL scheme")
})
t.Run("ConnectionRefused", func(t *testing.T) {
t.Parallel()
// Use a port that's unlikely to be in use
_, err := loadSchemaFromURL("http://127.0.0.1:59999/schema.json")
require.Error(t, err)
})
}
func TestLoadSchemaFromFile(t *testing.T) {
t.Parallel()
schemaContent := `{"type": "object"}`
t.Run("AbsolutePath", func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "schema.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))
data, err := loadSchemaFromFile("", "", schemaPath)
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
t.Run("FromWorkingDir", func(t *testing.T) {
t.Parallel()
workingDir := t.TempDir()
schemaPath := filepath.Join(workingDir, "schema.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))
data, err := loadSchemaFromFile(workingDir, "", "schema.json")
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
t.Run("FromDAGDirectory", func(t *testing.T) {
t.Parallel()
dagDir := t.TempDir()
schemaPath := filepath.Join(dagDir, "schema.json")
dagPath := filepath.Join(dagDir, "dag.yaml")
require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))
data, err := loadSchemaFromFile("", dagPath, "schema.json")
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
t.Run("WorkingDirTakesPrecedenceOverDAGDir", func(t *testing.T) {
t.Parallel()
workingDir := t.TempDir()
dagDir := t.TempDir()
wdSchema := `{"type": "object", "title": "working-dir"}`
dagSchema := `{"type": "object", "title": "dag-dir"}`
require.NoError(t, os.WriteFile(filepath.Join(workingDir, "schema.json"), []byte(wdSchema), 0600))
require.NoError(t, os.WriteFile(filepath.Join(dagDir, "schema.json"), []byte(dagSchema), 0600))
dagPath := filepath.Join(dagDir, "dag.yaml")
data, err := loadSchemaFromFile(workingDir, dagPath, "schema.json")
require.NoError(t, err)
assert.Equal(t, wdSchema, string(data))
})
t.Run("FileNotFound", func(t *testing.T) {
t.Parallel()
_, err := loadSchemaFromFile("", "", "nonexistent-schema.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
t.Run("FileNotFoundWithWorkingDir", func(t *testing.T) {
t.Parallel()
workingDir := t.TempDir()
_, err := loadSchemaFromFile(workingDir, "", "nonexistent.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
t.Run("FileNotFoundWithDAGDir", func(t *testing.T) {
t.Parallel()
dagDir := t.TempDir()
dagPath := filepath.Join(dagDir, "dag.yaml")
_, err := loadSchemaFromFile("", dagPath, "nonexistent.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
t.Run("SubdirectoryPath", func(t *testing.T) {
t.Parallel()
workingDir := t.TempDir()
schemasDir := filepath.Join(workingDir, "schemas")
require.NoError(t, os.MkdirAll(schemasDir, 0755))
schemaPath := filepath.Join(schemasDir, "params.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))
data, err := loadSchemaFromFile(workingDir, "", "schemas/params.json")
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
t.Run("EmptyWorkingDirAndDAGLocation", func(t *testing.T) {
t.Parallel()
_, err := loadSchemaFromFile("", "", "schema.json")
require.Error(t, err)
})
t.Run("WhitespaceOnlyWorkingDir", func(t *testing.T) {
t.Parallel()
dagDir := t.TempDir()
schemaPath := filepath.Join(dagDir, "schema.json")
dagPath := filepath.Join(dagDir, "dag.yaml")
require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))
// Whitespace-only workingDir should be skipped, fall back to dagDir
data, err := loadSchemaFromFile(" ", dagPath, "schema.json")
require.NoError(t, err)
assert.Equal(t, schemaContent, string(data))
})
}
func TestGetSchemaFromRef(t *testing.T) {
t.Parallel()
validSchemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"name": {"type": "string"}
}
}`
t.Run("LocalFileSchema", func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "schema.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))
resolved, err := getSchemaFromRef("", "", schemaPath)
require.NoError(t, err)
assert.NotNil(t, resolved)
})
t.Run("RemoteURLSchema", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(validSchemaContent))
}))
defer server.Close()
resolved, err := getSchemaFromRef("", "", server.URL+"/schema.json")
require.NoError(t, err)
assert.NotNil(t, resolved)
})
t.Run("HTTPSSchemaReference", func(t *testing.T) {
t.Parallel()
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(validSchemaContent))
}))
defer server.Close()
// This will fail due to self-signed cert, but tests the https:// detection
_, err := getSchemaFromRef("", "", server.URL+"/schema.json")
// We expect an error due to certificate verification
require.Error(t, err)
})
t.Run("InvalidJSONSchema", func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "invalid.json")
require.NoError(t, os.WriteFile(schemaPath, []byte("not valid json"), 0600))
_, err := getSchemaFromRef("", "", schemaPath)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse schema JSON")
})
t.Run("SchemaFileNotFound", func(t *testing.T) {
t.Parallel()
_, err := getSchemaFromRef("", "", "nonexistent.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to load schema")
})
t.Run("URLNotFound", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
_, err := getSchemaFromRef("", "", server.URL+"/missing.json")
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to load schema")
})
}
func TestResolveSchemaFromParams(t *testing.T) {
t.Parallel()
validSchemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {"type": "integer", "default": 10}
}
}`
t.Run("NoSchemaReference", func(t *testing.T) {
t.Parallel()
resolved, err := resolveSchemaFromParams(nil, "", "")
require.NoError(t, err)
assert.Nil(t, resolved)
})
t.Run("ParamsNotMap", func(t *testing.T) {
t.Parallel()
resolved, err := resolveSchemaFromParams("string params", "", "")
require.NoError(t, err)
assert.Nil(t, resolved)
})
t.Run("ParamsWithoutSchema", func(t *testing.T) {
t.Parallel()
params := map[string]any{
"values": map[string]any{"foo": "bar"},
}
resolved, err := resolveSchemaFromParams(params, "", "")
require.NoError(t, err)
assert.Nil(t, resolved)
})
t.Run("ParamsWithValidSchema", func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
schemaPath := filepath.Join(tmpDir, "schema.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))
params := map[string]any{
"schema": schemaPath,
"values": map[string]any{"batch_size": 20},
}
resolved, err := resolveSchemaFromParams(params, "", "")
require.NoError(t, err)
assert.NotNil(t, resolved)
})
t.Run("ParamsWithInvalidSchemaPath", func(t *testing.T) {
t.Parallel()
params := map[string]any{
"schema": "nonexistent.json",
}
_, err := resolveSchemaFromParams(params, "", "")
require.Error(t, err)
})
t.Run("ParamsWithRemoteSchema", func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(validSchemaContent))
}))
defer server.Close()
params := map[string]any{
"schema": server.URL + "/schema.json",
}
resolved, err := resolveSchemaFromParams(params, "", "")
require.NoError(t, err)
assert.NotNil(t, resolved)
})
t.Run("UsesWorkingDirForResolution", func(t *testing.T) {
t.Parallel()
workingDir := t.TempDir()
schemaPath := filepath.Join(workingDir, "schema.json")
require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))
params := map[string]any{
"schema": "schema.json",
}
resolved, err := resolveSchemaFromParams(params, workingDir, "")
require.NoError(t, err)
assert.NotNil(t, resolved)
})
t.Run("UsesDAGLocationForResolution", func(t *testing.T) {
t.Parallel()
dagDir := t.TempDir()
schemaPath := filepath.Join(dagDir, "schema.json")
dagPath := filepath.Join(dagDir, "dag.yaml")
require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))
params := map[string]any{
"schema": "schema.json",
}
resolved, err := resolveSchemaFromParams(params, "", dagPath)
require.NoError(t, err)
assert.NotNil(t, resolved)
})
}
func TestBuildParamsWithLocalSchemaReference(t *testing.T) {
t.Parallel()
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {
"type": "integer",
"default": 10,
"minimum": 1
},
"environment": {
"type": "string",
"default": "dev",
"enum": ["dev", "staging", "prod"]
}
}
}`
tmpFile, err := os.CreateTemp("", "test-schema-*.json")
require.NoError(t, err)
defer func() { _ = os.Remove(tmpFile.Name()) }()
_, err = tmpFile.WriteString(schemaContent)
require.NoError(t, err)
require.NoError(t, tmpFile.Close())
data := []byte(fmt.Sprintf(`
params:
schema: "%s"
values:
batch_size: 25
environment: "staging"
`, tmpFile.Name()))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Params, 2)
require.Contains(t, dag.Params, "batch_size=25")
require.Contains(t, dag.Params, "environment=staging")
}
func TestBuildParamsWithRemoteSchemaReference(t *testing.T) {
t.Parallel()
mux := http.NewServeMux()
mux.HandleFunc("/schemas/dag-params.json", func(w http.ResponseWriter, _ *http.Request) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {
"type": "integer",
"default": 10,
"minimum": 1
},
"environment": {
"type": "string",
"default": "dev",
"enum": ["dev", "staging", "prod"]
}
}
}`
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(schemaContent))
})
server := httptest.NewServer(mux)
defer server.Close()
data := []byte(fmt.Sprintf(`
params:
schema: "%s/schemas/dag-params.json"
values:
batch_size: 50
environment: "prod"
`, server.URL))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Params, 2)
require.Contains(t, dag.Params, "batch_size=50")
require.Contains(t, dag.Params, "environment=prod")
}
func TestBuildParamsSchemaResolution(t *testing.T) {
t.Run("FromWorkingDir", func(t *testing.T) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {"type": "integer", "default": 42}
}
}`
wd := t.TempDir()
wdSchema := filepath.Join(wd, "schema.json")
require.NoError(t, os.WriteFile(wdSchema, []byte(schemaContent), 0600))
origWD, err := os.Getwd()
require.NoError(t, err)
t.Cleanup(func() {
if err := os.Chdir(origWD); err != nil {
t.Fatalf("failed to restore working directory: %v", err)
}
})
data := []byte(fmt.Sprintf(`
workingDir: %s
params:
schema: "schema.json"
values:
environment: "dev"
`, wd))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Contains(t, dag.Params, "batch_size=42")
})
t.Run("FromDAGDir", func(t *testing.T) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {"type": "integer", "default": 7}
}
}`
dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, "schema.json"), []byte(schemaContent), 0600))
dagYaml := []byte(`
params:
schema: "schema.json"
values:
environment: "staging"
`)
dagPath := filepath.Join(dir, "dag.yaml")
require.NoError(t, os.WriteFile(dagPath, dagYaml, 0600))
dag, err := Load(context.Background(), dagPath)
require.NoError(t, err)
require.Contains(t, dag.Params, "batch_size=7")
})
t.Run("PrefersCWDOverWorkingDir", func(t *testing.T) {
cwdSchemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {"type": "integer", "default": 99}
}
}`
wdSchemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {"type": "integer", "default": 11}
}
}`
cwd := t.TempDir()
wd := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(cwd, "schema.json"), []byte(cwdSchemaContent), 0600))
require.NoError(t, os.WriteFile(filepath.Join(wd, "schema.json"), []byte(wdSchemaContent), 0600))
orig, err := os.Getwd()
require.NoError(t, err)
require.NoError(t, os.Chdir(cwd))
defer func() { _ = os.Chdir(orig) }()
data := []byte(fmt.Sprintf(`
workingDir: %s
params:
schema: "schema.json"
values:
environment: "dev"
`, wd))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Contains(t, dag.Params, "batch_size=99")
})
}
func TestBuildParamsSchemaValidation(t *testing.T) {
t.Parallel()
t.Run("SkipSchemaValidationFlag", func(t *testing.T) {
t.Parallel()
data := []byte(`
params:
schema: "missing-schema.json"
values:
foo: "bar"
`)
_, err := LoadYAML(context.Background(), data)
require.Error(t, err)
dag, err := LoadYAMLWithOpts(context.Background(), data, BuildOpts{
Flags: BuildFlagSkipSchemaValidation,
})
require.NoError(t, err)
require.Len(t, dag.Params, 1)
require.Contains(t, dag.Params, "foo=bar")
})
t.Run("OverrideValidationFails", func(t *testing.T) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {
"type": "integer",
"default": 10,
"minimum": 1,
"maximum": 50
},
"environment": {
"type": "string",
"default": "dev",
"enum": ["dev", "staging", "prod"]
}
}
}`
tmpFile, err := os.CreateTemp("", "test-schema-validation-*.json")
require.NoError(t, err)
defer func() { _ = os.Remove(tmpFile.Name()) }()
_, err = tmpFile.WriteString(schemaContent)
require.NoError(t, err)
require.NoError(t, tmpFile.Close())
data := []byte(fmt.Sprintf(`
params:
schema: "%s"
`, tmpFile.Name()))
cliParams := "batch_size=100 environment=prod"
_, err = LoadYAML(context.Background(), data, WithParams(cliParams))
require.Error(t, err)
require.Contains(t, err.Error(), "parameter validation failed")
require.Contains(t, err.Error(), "maximum: 100/1 is greater than 50")
})
t.Run("DefaultsApplied", func(t *testing.T) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {
"type": "integer",
"default": 25,
"minimum": 1,
"maximum": 100
},
"environment": {
"type": "string",
"default": "development",
"enum": ["development", "staging", "production"]
},
"debug": {
"type": "boolean",
"default": true
}
}
}`
tmpFile, err := os.CreateTemp("", "test-schema-defaults-*.json")
require.NoError(t, err)
defer func() { _ = os.Remove(tmpFile.Name()) }()
_, err = tmpFile.WriteString(schemaContent)
require.NoError(t, err)
require.NoError(t, tmpFile.Close())
data := []byte(fmt.Sprintf(`
params:
schema: "%s"
values:
batch_size: 75
`, tmpFile.Name()))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Params, 3)
require.Contains(t, dag.Params, "batch_size=75")
require.Contains(t, dag.Params, "environment=development")
require.Contains(t, dag.Params, "debug=true")
})
t.Run("DefaultsPreserveExistingValues", func(t *testing.T) {
schemaContent := `{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"batch_size": {
"type": "integer",
"default": 25,
"minimum": 1,
"maximum": 100
},
"environment": {
"type": "string",
"default": "development",
"enum": ["development", "staging", "production"]
},
"debug": {
"type": "boolean",
"default": true
},
"timeout": {
"type": "integer",
"default": 300
}
}
}`
tmpFile, err := os.CreateTemp("", "test-schema-preserve-*.json")
require.NoError(t, err)
defer func() { _ = os.Remove(tmpFile.Name()) }()
_, err = tmpFile.WriteString(schemaContent)
require.NoError(t, err)
require.NoError(t, tmpFile.Close())
data := []byte(fmt.Sprintf(`
params:
schema: "%s"
values:
batch_size: 50
environment: "production"
debug: false
timeout: 600
`, tmpFile.Name()))
dag, err := LoadYAML(context.Background(), data)
require.NoError(t, err)
require.Len(t, dag.Params, 4)
require.Contains(t, dag.Params, "batch_size=50")
require.Contains(t, dag.Params, "environment=production")
require.Contains(t, dag.Params, "debug=false")
require.Contains(t, dag.Params, "timeout=600")
})
}

internal/core/spec/step.go (new file, 1214 changes)

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,207 @@
package types
import (
"fmt"
"strconv"
"strings"
"github.com/goccy/go-yaml"
)
// ContinueOnValue represents a continue-on configuration that can be specified as:
// - A string shorthand: "skipped" or "failed"
// - A detailed map with configuration options
//
// YAML examples:
//
// continueOn: skipped
// continueOn: failed
// continueOn:
// skipped: true
// failed: true
// exitCode: [0, 1]
// output: ["pattern1", "pattern2"]
// markSuccess: true
type ContinueOnValue struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
skipped bool // Continue on skipped
failed bool // Continue on failed
exitCode []int // Specific exit codes to continue on
output []string // Output patterns to match
markSuccess bool // Mark step as success when condition is met
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (c *ContinueOnValue) UnmarshalYAML(data []byte) error {
c.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("continueOn unmarshal error: %w", err)
}
c.raw = raw
switch v := raw.(type) {
case string:
// String shorthand: "skipped" or "failed"
switch strings.ToLower(strings.TrimSpace(v)) {
case "skipped":
c.skipped = true
case "failed":
c.failed = true
default:
return fmt.Errorf("continueOn: expected 'skipped' or 'failed', got %q", v)
}
return nil
case map[string]any:
// Detailed configuration
return c.parseMap(v)
case nil:
c.isSet = false
return nil
default:
return fmt.Errorf("continueOn must be string or map, got %T", v)
}
}
func (c *ContinueOnValue) parseMap(m map[string]any) error {
for key, v := range m {
switch key {
case "skipped":
if b, ok := v.(bool); ok {
c.skipped = b
} else {
return fmt.Errorf("continueOn.skipped: expected boolean, got %T", v)
}
case "failed", "failure":
if b, ok := v.(bool); ok {
c.failed = b
} else {
return fmt.Errorf("continueOn.%s: expected boolean, got %T", key, v)
}
case "exitCode":
codes, err := parseIntArray(v)
if err != nil {
return fmt.Errorf("continueOn.exitCode: %w", err)
}
c.exitCode = codes
case "output":
outputs, err := parseStringArray(v)
if err != nil {
return fmt.Errorf("continueOn.output: %w", err)
}
c.output = outputs
case "markSuccess":
if b, ok := v.(bool); ok {
c.markSuccess = b
} else {
return fmt.Errorf("continueOn.markSuccess: expected boolean, got %T", v)
}
default:
return fmt.Errorf("continueOn: unknown key %q", key)
}
}
return nil
}
func parseStringArray(v any) ([]string, error) {
switch val := v.(type) {
case nil:
return nil, nil
case string:
if val == "" {
return nil, nil
}
return []string{val}, nil
case []any:
var result []string
for i, item := range val {
s, ok := item.(string)
if !ok {
return nil, fmt.Errorf("[%d]: expected string, got %T", i, item)
}
result = append(result, s)
}
return result, nil
default:
return nil, fmt.Errorf("expected string or array of strings, got %T", v)
}
}
func parseIntArray(v any) ([]int, error) {
switch val := v.(type) {
case []any:
var result []int
for i, item := range val {
n, err := parseIntValue(item)
if err != nil {
return nil, fmt.Errorf("[%d]: %w", i, err)
}
result = append(result, n)
}
return result, nil
case int:
return []int{val}, nil
case int64:
return []int{int(val)}, nil
case float64:
return []int{int(val)}, nil
case uint64:
// Exit codes are small numbers, so overflow won't happen in practice
return []int{int(val)}, nil //nolint:gosec // Exit codes are small numbers
case string:
n, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("cannot parse %q as int: %w", val, err)
}
return []int{n}, nil
default:
return nil, fmt.Errorf("expected int or array of ints, got %T", v)
}
}
func parseIntValue(item any) (int, error) {
switch v := item.(type) {
case int:
return v, nil
case int64:
return int(v), nil
case float64:
return int(v), nil
case uint64:
return int(v), nil //nolint:gosec // Exit codes are small numbers
case string:
n, err := strconv.Atoi(v)
if err != nil {
return 0, fmt.Errorf("cannot parse %q as int: %w", v, err)
}
return n, nil
default:
return 0, fmt.Errorf("expected int, got %T", item)
}
}
// IsZero returns true if continueOn was not set in YAML.
func (c ContinueOnValue) IsZero() bool { return !c.isSet }
// Value returns the original raw value for error reporting.
func (c ContinueOnValue) Value() any { return c.raw }
// Skipped returns true if should continue on skipped.
func (c ContinueOnValue) Skipped() bool { return c.skipped }
// Failed returns true if should continue on failed.
func (c ContinueOnValue) Failed() bool { return c.failed }
// ExitCode returns exit codes to continue on.
func (c ContinueOnValue) ExitCode() []int { return c.exitCode }
// Output returns output patterns to match.
func (c ContinueOnValue) Output() []string { return c.output }
// MarkSuccess returns true if step should be marked as success when condition is met.
func (c ContinueOnValue) MarkSuccess() bool { return c.markSuccess }
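A minimal round-trip sketch of both accepted forms, assuming the goccy/go-yaml wiring demonstrated in the tests that follow:

var step struct {
	ContinueOn ContinueOnValue `yaml:"continueOn"`
}
// String shorthand.
_ = yaml.Unmarshal([]byte(`continueOn: failed`), &step)
// step.ContinueOn.Failed() == true
// Detailed map form.
_ = yaml.Unmarshal([]byte("continueOn:\n  failed: true\n  exitCode: [1, 2]"), &step)
// step.ContinueOn.ExitCode() == []int{1, 2}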


@@ -0,0 +1,413 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestContinueOnValue_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantSkipped bool
wantFailed bool
wantExitCode []int
wantOutput []string
wantMarkSuccess bool
checkIsZero bool
checkNotZero bool
}{
{
name: "StringSkipped",
input: "skipped",
wantSkipped: true,
wantFailed: false,
checkNotZero: true,
},
{
name: "StringFailed",
input: "failed",
wantSkipped: false,
wantFailed: true,
},
{
name: "StringCaseInsensitiveSKIPPED",
input: "SKIPPED",
wantSkipped: true,
},
{
name: "StringCaseInsensitiveFailed",
input: "Failed",
wantFailed: true,
},
{
name: "StringWithWhitespace",
input: `" skipped "`,
wantSkipped: true,
},
{
name: "MapFormSkippedOnly",
input: "skipped: true",
wantSkipped: true,
wantFailed: false,
},
{
name: "MapFormFailedOnly",
input: "failed: true",
wantSkipped: false,
wantFailed: true,
},
{
name: "MapFormBoth",
input: `
skipped: true
failed: true
`,
wantSkipped: true,
wantFailed: true,
},
{
name: "MapWithExitCodesArray",
input: "exitCode: [0, 1, 2]",
wantExitCode: []int{0, 1, 2},
},
{
name: "MapWithSingleExitCode",
input: "exitCode: 1",
wantExitCode: []int{1},
},
{
name: "MapWithOutputPattern",
input: `output: "success|warning"`,
wantOutput: []string{"success|warning"},
},
{
name: "MapWithAllFields",
input: `
skipped: true
failed: true
exitCode: [0, 1]
output: "OK"
markSuccess: true
`,
wantSkipped: true,
wantFailed: true,
wantExitCode: []int{0, 1},
wantOutput: []string{"OK"},
wantMarkSuccess: true,
},
{
name: "InvalidStringValue",
input: "invalid",
wantErr: true,
errContains: "expected 'skipped' or 'failed'",
},
{
name: "InvalidMapKey",
input: "unknown: true",
wantErr: true,
errContains: "unknown key",
},
{
name: "InvalidSkippedType",
input: `skipped: "yes"`,
wantErr: true,
errContains: "expected bool",
},
{
name: "InvalidExitCodeType",
input: `exitCode: "not a number"`,
wantErr: true,
errContains: "cannot parse",
},
{
name: "InvalidTypeArray",
input: "[1, 2, 3]",
wantErr: true,
errContains: "must be string or map",
},
{
name: "OutputAsStringArray",
input: `output: ["success", "warning", "info"]`,
wantOutput: []string{"success", "warning", "info"},
},
{
name: "ExitCodeAsInt64",
input: "exitCode: 255",
wantExitCode: []int{255},
},
{
name: "ExitCodeAsString",
input: `exitCode: "42"`,
wantExitCode: []int{42},
},
{
name: "ExitCodeArrayWithMixedTypes",
input: `exitCode: [0, "1", 2]`,
wantExitCode: []int{0, 1, 2},
},
{
name: "OutputInvalidTypeInArray",
input: `output: [123, true]`,
wantErr: true,
errContains: "expected string",
},
{
name: "ExitCodeInvalidString",
input: `exitCode: "not-a-number"`,
wantErr: true,
errContains: "cannot parse",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var c types.ContinueOnValue
err := yaml.Unmarshal([]byte(tt.input), &c)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.checkIsZero {
assert.True(t, c.IsZero())
return
}
if tt.checkNotZero {
assert.False(t, c.IsZero())
}
if tt.wantSkipped {
assert.True(t, c.Skipped())
}
if tt.wantFailed {
assert.True(t, c.Failed())
}
if tt.wantExitCode != nil {
assert.Equal(t, tt.wantExitCode, c.ExitCode())
}
if tt.wantOutput != nil {
assert.Equal(t, tt.wantOutput, c.Output())
}
if tt.wantMarkSuccess {
assert.True(t, c.MarkSuccess())
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var c types.ContinueOnValue
assert.True(t, c.IsZero())
assert.False(t, c.Skipped())
assert.False(t, c.Failed())
})
}
func TestContinueOnValue_InStruct(t *testing.T) {
t.Parallel()
type StepConfig struct {
Name string `yaml:"name"`
ContinueOn types.ContinueOnValue `yaml:"continueOn"`
}
tests := []struct {
name string
input string
wantSkipped bool
wantFailed bool
wantExitCode []int
wantIsZero bool
}{
{
name: "ContinueOnAsString",
input: `
name: my-step
continueOn: skipped
`,
wantSkipped: true,
},
{
name: "ContinueOnAsMap",
input: `
name: my-step
continueOn:
failed: true
exitCode: [0, 1]
`,
wantFailed: true,
wantExitCode: []int{0, 1},
},
{
name: "ContinueOnNotSet",
input: "name: my-step",
wantIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg StepConfig
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantIsZero {
assert.True(t, cfg.ContinueOn.IsZero())
return
}
if tt.wantSkipped {
assert.True(t, cfg.ContinueOn.Skipped())
}
if tt.wantFailed {
assert.True(t, cfg.ContinueOn.Failed())
}
if tt.wantExitCode != nil {
assert.Equal(t, tt.wantExitCode, cfg.ContinueOn.ExitCode())
}
})
}
}
func TestContinueOnValue_Value(t *testing.T) {
t.Parallel()
t.Run("ValueReturnsRawString", func(t *testing.T) {
t.Parallel()
var c types.ContinueOnValue
err := yaml.Unmarshal([]byte("skipped"), &c)
require.NoError(t, err)
assert.Equal(t, "skipped", c.Value())
})
t.Run("ValueReturnsMap", func(t *testing.T) {
t.Parallel()
var c types.ContinueOnValue
err := yaml.Unmarshal([]byte("failed: true"), &c)
require.NoError(t, err)
val, ok := c.Value().(map[string]any)
require.True(t, ok)
assert.Equal(t, true, val["failed"])
})
}
func TestContinueOnValue_EdgeCases(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
checkIsZero bool
wantFailed bool
wantExitCode []int
wantOutputNil bool
}{
{
name: "NullValue",
input: "null",
checkIsZero: true,
},
{
name: "FailureKeyAlias",
input: "failure: true",
wantFailed: true,
},
{
name: "ExitCodeAsFloat",
input: "exitCode: 1.0",
wantExitCode: []int{1},
},
{
name: "ExitCodeArrayWithFloat",
input: "exitCode: [1.0, 2.0]",
wantExitCode: []int{1, 2},
},
{
name: "InvalidExitCodeTypeInArray",
input: "exitCode: [true]",
wantErr: true,
errContains: "expected int",
},
{
name: "InvalidExitCodeTypeNotIntOrArray",
input: "exitCode: {key: value}",
wantErr: true,
errContains: "expected int or array",
},
{
name: "OutputAsNil",
input: "output: null",
wantOutputNil: true,
},
{
name: "OutputAsEmptyString",
input: `output: ""`,
wantOutputNil: true,
},
{
name: "OutputInvalidTypeNotStringOrArray",
input: "output: 123",
wantErr: true,
errContains: "expected string or array",
},
{
name: "MarkSuccessInvalidType",
input: `markSuccess: "yes"`,
wantErr: true,
errContains: "expected bool",
},
{
name: "FailedInvalidType",
input: `failed: "yes"`,
wantErr: true,
errContains: "expected bool",
},
{
name: "ExitCodeInvalidStringInArray",
input: `exitCode: ["not-a-number"]`,
wantErr: true,
errContains: "cannot parse",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var c types.ContinueOnValue
err := yaml.Unmarshal([]byte(tt.input), &c)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.checkIsZero {
assert.True(t, c.IsZero())
}
if tt.wantFailed {
assert.True(t, c.Failed())
}
if tt.wantExitCode != nil {
assert.Equal(t, tt.wantExitCode, c.ExitCode())
}
if tt.wantOutputNil {
assert.Nil(t, c.Output())
}
})
}
}


@@ -0,0 +1,15 @@
// Package types provides typed union types for YAML fields that accept multiple formats.
//
// These types provide type-safe unmarshaling with early validation while maintaining
// full backward compatibility with existing YAML files.
//
// Design principles:
// 1. Each type captures all valid YAML representations for a field
// 2. Validation happens at unmarshal time, not at build time
// 3. Accessor methods provide type-safe access to the parsed values
// 4. IsZero() indicates whether the field was set in YAML
//
// The types in this package are designed to be used as drop-in replacements
// for `any` typed fields in the spec.definition struct, enabling gradual
// migration while maintaining backward compatibility.
package types
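As a sketch of the migration path this package describes; the struct and field names here are illustrative, not the actual spec.definition:

// Before: an untyped field, validated later at build time.
type stepDefLoose struct {
	Env any `yaml:"env"`
}

// After: a union type validated at unmarshal time; IsZero() reports
// whether the field was present in the YAML document.
type stepDefTyped struct {
	Env EnvValue `yaml:"env"`
}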


@@ -0,0 +1,114 @@
package types
import (
"fmt"
"strings"
"github.com/goccy/go-yaml"
)
// EnvValue represents environment variable configuration that can be specified as:
// - A map of key-value pairs
// - An array of maps (for ordered definitions)
// - An array of "KEY=value" strings
// - A mix of maps and strings in an array
//
// YAML examples:
//
// env:
// KEY1: value1
// KEY2: value2
//
// env:
// - KEY1: value1
// - KEY2: value2
//
// env:
// - KEY1=value1
// - KEY2=value2
type EnvValue struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
entries []EnvEntry // Parsed entries in order
}
// EnvEntry represents a single environment variable entry.
type EnvEntry struct {
Key string
Value string
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (e *EnvValue) UnmarshalYAML(data []byte) error {
e.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("env unmarshal error: %w", err)
}
e.raw = raw
switch v := raw.(type) {
case map[string]any:
// Map of key-value pairs
return e.parseMap(v)
case []any:
// Array of maps or strings
return e.parseArray(v)
case nil:
e.isSet = false
return nil
default:
return fmt.Errorf("env must be map or array, got %T", v)
}
}
func (e *EnvValue) parseMap(m map[string]any) error {
for key, v := range m {
value := stringifyValue(v)
e.entries = append(e.entries, EnvEntry{Key: key, Value: value})
}
return nil
}
func (e *EnvValue) parseArray(arr []any) error {
for i, item := range arr {
switch v := item.(type) {
case map[string]any:
for key, val := range v {
value := stringifyValue(val)
e.entries = append(e.entries, EnvEntry{Key: key, Value: value})
}
case string:
key, val, found := strings.Cut(v, "=")
if !found {
return fmt.Errorf("env[%d]: invalid format %q (expected KEY=value)", i, v)
}
e.entries = append(e.entries, EnvEntry{Key: key, Value: val})
default:
return fmt.Errorf("env[%d]: expected map or string, got %T", i, item)
}
}
return nil
}
func stringifyValue(v any) string {
switch val := v.(type) {
case string:
return val
default:
return fmt.Sprintf("%v", val)
}
}
// IsZero returns true if env was not set in YAML.
func (e EnvValue) IsZero() bool { return !e.isSet }
// Value returns the original raw value for error reporting.
func (e EnvValue) Value() any { return e.raw }
// Entries returns the parsed environment entries in order.
func (e EnvValue) Entries() []EnvEntry { return e.entries }
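A short sketch of why the array form matters: Go map iteration carries no order, while array entries are appended in document order (yaml wiring as in the tests below):

var e EnvValue
_ = yaml.Unmarshal([]byte("- KEY1: a\n- KEY2=b"), &e)
for _, entry := range e.Entries() {
	fmt.Println(entry.Key + "=" + entry.Value) // prints KEY1=a, then KEY2=b
}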


@@ -0,0 +1,320 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEnvValue_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantEntryCount int
wantEntries map[string]string
checkIsZero bool
checkNotZero bool
checkEmpty bool
}{
{
name: "MapForm",
input: `
KEY1: value1
KEY2: value2
`,
wantEntryCount: 2,
wantEntries: map[string]string{"KEY1": "value1", "KEY2": "value2"},
},
{
name: "ArrayOfMapsPreservesOrder",
input: `
- KEY1: value1
- KEY2: value2
- KEY3: value3
`,
wantEntryCount: 3,
wantEntries: map[string]string{"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"},
},
{
name: "ArrayOfStrings",
input: `
- KEY1=value1
- KEY2=value2
`,
wantEntryCount: 2,
wantEntries: map[string]string{"KEY1": "value1"},
},
{
name: "MixedArrayMapsAndStrings",
input: `
- KEY1: value1
- KEY2=value2
- KEY3: value3
`,
wantEntryCount: 3,
},
{
name: "NumericValuesStringified",
input: `
PORT: 8080
ENABLED: true
RATIO: 0.5
`,
wantEntryCount: 3,
wantEntries: map[string]string{"PORT": "8080", "ENABLED": "true", "RATIO": "0.5"},
},
{
name: "ValueWithEqualsSign",
input: `
- CONNECTION_STRING=host=localhost;port=5432
`,
wantEntryCount: 1,
wantEntries: map[string]string{"CONNECTION_STRING": "host=localhost;port=5432"},
},
{
name: "InvalidStringFormatNoEquals",
input: `["invalid_no_equals"]`,
wantErr: true,
errContains: "expected KEY=value",
},
{
name: "InvalidTypeScalarString",
input: `"just a string"`,
wantErr: true,
errContains: "must be map or array",
},
{
name: "EmptyMap",
input: "{}",
checkNotZero: true,
checkEmpty: true,
},
{
name: "EmptyArray",
input: "[]",
checkNotZero: true,
checkEmpty: true,
},
{
name: "EnvironmentVariableReference",
input: `
- PATH: ${HOME}/bin
- DERIVED: ${OTHER_VAR}
`,
wantEntryCount: 2,
wantEntries: map[string]string{"PATH": "${HOME}/bin", "DERIVED": "${OTHER_VAR}"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var e types.EnvValue
err := yaml.Unmarshal([]byte(tt.input), &e)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.checkIsZero {
assert.True(t, e.IsZero())
return
}
if tt.checkNotZero {
assert.False(t, e.IsZero())
}
if tt.checkEmpty {
assert.Empty(t, e.Entries())
return
}
if tt.wantEntryCount > 0 {
assert.Len(t, e.Entries(), tt.wantEntryCount)
}
if tt.wantEntries != nil {
entries := e.Entries()
keys := make(map[string]string)
for _, entry := range entries {
keys[entry.Key] = entry.Value
}
for k, v := range tt.wantEntries {
assert.Equal(t, v, keys[k], "key %s should have value %s", k, v)
}
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var e types.EnvValue
assert.True(t, e.IsZero())
assert.Nil(t, e.Entries())
})
}
func TestEnvValue_InStruct(t *testing.T) {
t.Parallel()
type StepConfig struct {
Name string `yaml:"name"`
Env types.EnvValue `yaml:"env"`
}
tests := []struct {
name string
input string
wantName string
wantEntryCount int
wantIsZero bool
checkNotZero bool
}{
{
name: "EnvSetAsMap",
input: `
name: my-step
env:
DEBUG: "true"
LOG_LEVEL: info
`,
wantName: "my-step",
wantEntryCount: 2,
checkNotZero: true,
},
{
name: "EnvSetAsArray",
input: `
name: my-step
env:
- DEBUG: "true"
- LOG_LEVEL: info
`,
wantEntryCount: 2,
},
{
name: "EnvNotSet",
input: "name: my-step",
wantIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg StepConfig
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantName != "" {
assert.Equal(t, tt.wantName, cfg.Name)
}
if tt.wantEntryCount > 0 {
assert.Len(t, cfg.Env.Entries(), tt.wantEntryCount)
}
if tt.wantIsZero {
assert.True(t, cfg.Env.IsZero())
}
if tt.checkNotZero {
assert.False(t, cfg.Env.IsZero())
}
})
}
t.Run("EnvSetAsArrayPreservesOrder", func(t *testing.T) {
t.Parallel()
var cfg StepConfig
err := yaml.Unmarshal([]byte(`
name: my-step
env:
- DEBUG: "true"
- LOG_LEVEL: info
`), &cfg)
require.NoError(t, err)
entries := cfg.Env.Entries()
require.Len(t, entries, 2)
assert.Equal(t, "DEBUG", entries[0].Key)
assert.Equal(t, "LOG_LEVEL", entries[1].Key)
})
}
func TestEnvValue_AdditionalCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
checkIsZero bool
}{
{
name: "InvalidTypeInArrayNumber",
input: "[123]",
wantErr: true,
errContains: "expected map or string",
},
{
name: "InvalidTypeInArrayBoolean",
input: "[true]",
wantErr: true,
errContains: "expected map or string",
},
{
name: "InvalidTypeNumber",
input: "123",
wantErr: true,
errContains: "must be map or array",
},
{
name: "NullValue",
input: "null",
checkIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var e types.EnvValue
err := yaml.Unmarshal([]byte(tt.input), &e)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.checkIsZero {
assert.True(t, e.IsZero())
}
})
}
t.Run("ValueReturnsRawMap", func(t *testing.T) {
t.Parallel()
var e types.EnvValue
err := yaml.Unmarshal([]byte("KEY: value"), &e)
require.NoError(t, err)
val, ok := e.Value().(map[string]any)
require.True(t, ok)
assert.Equal(t, "value", val["KEY"])
})
t.Run("ValueReturnsRawArray", func(t *testing.T) {
t.Parallel()
var e types.EnvValue
err := yaml.Unmarshal([]byte("[KEY=value]"), &e)
require.NoError(t, err)
val, ok := e.Value().([]any)
require.True(t, ok)
assert.Len(t, val, 1)
})
}


@@ -0,0 +1,77 @@
package types
import (
"fmt"
"strings"
"github.com/dagu-org/dagu/internal/core"
"github.com/goccy/go-yaml"
)
// LogOutputValue represents a log output configuration that can be unmarshaled from YAML.
// It accepts a string value that must be either "separate" or "merged".
// This type uses core.LogOutputMode to avoid type duplication.
type LogOutputValue struct {
mode core.LogOutputMode
set bool // whether the value was explicitly set in YAML
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (l *LogOutputValue) UnmarshalYAML(data []byte) error {
l.set = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("logOutput unmarshal error: %w", err)
}
switch v := raw.(type) {
case string:
value := strings.TrimSpace(strings.ToLower(v))
switch value {
case "separate", "":
l.mode = core.LogOutputSeparate
case "merged":
l.mode = core.LogOutputMerged
default:
return fmt.Errorf("invalid logOutput value: %q (must be 'separate' or 'merged')", v)
}
return nil
case nil:
l.set = false
return nil
default:
return fmt.Errorf("logOutput must be a string, got %T", v)
}
}
// IsZero returns true if the value was not set in YAML.
func (l LogOutputValue) IsZero() bool {
return !l.set
}
// Mode returns the log output mode.
// If the value was not set, it returns core.LogOutputSeparate as the default.
func (l LogOutputValue) Mode() core.LogOutputMode {
if !l.set {
return core.LogOutputSeparate
}
return l.mode
}
// String returns the string representation of the log output mode.
func (l LogOutputValue) String() string {
return string(l.Mode())
}
// IsMerged returns true if the log output mode is merged.
func (l LogOutputValue) IsMerged() bool {
return l.mode == core.LogOutputMerged
}
// IsSeparate returns true if the log output mode is separate.
func (l LogOutputValue) IsSeparate() bool {
return !l.set || l.mode == core.LogOutputSeparate
}
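A compact usage sketch mirroring the tests that follow; note that an unset value falls back to separate:

var cfg struct {
	LogOutput LogOutputValue `yaml:"logOutput"`
}
_ = yaml.Unmarshal([]byte(`logOutput: merged`), &cfg)
// cfg.LogOutput.IsMerged() == true
var unset LogOutputValue
// unset.IsZero() == true; unset.Mode() == core.LogOutputSeparate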


@@ -0,0 +1,163 @@
package types
import (
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogOutputValue_UnmarshalYAML(t *testing.T) {
tests := []struct {
name string
input string
wantMode core.LogOutputMode
wantSet bool
wantErr bool
errContains string
}{
{
name: "separate mode",
input: "logOutput: separate",
wantMode: core.LogOutputSeparate,
wantSet: true,
},
{
name: "merged mode",
input: "logOutput: merged",
wantMode: core.LogOutputMerged,
wantSet: true,
},
{
name: "merged mode uppercase",
input: "logOutput: MERGED",
wantMode: core.LogOutputMerged,
wantSet: true,
},
{
name: "separate mode mixed case",
input: "logOutput: Separate",
wantMode: core.LogOutputSeparate,
wantSet: true,
},
{
name: "empty string defaults to separate",
input: "logOutput: ''",
wantMode: core.LogOutputSeparate,
wantSet: true,
},
{
name: "invalid value",
input: "logOutput: invalid",
wantErr: true,
errContains: "invalid logOutput value",
},
{
name: "invalid value - both",
input: "logOutput: both",
wantErr: true,
errContains: "invalid logOutput value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result struct {
LogOutput LogOutputValue `yaml:"logOutput"`
}
err := yaml.Unmarshal([]byte(tt.input), &result)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantMode, result.LogOutput.Mode())
assert.Equal(t, tt.wantSet, !result.LogOutput.IsZero())
})
}
}
func TestLogOutputValue_UnmarshalYAML_InvalidType(t *testing.T) {
var result struct {
LogOutput LogOutputValue `yaml:"logOutput"`
}
err := yaml.Unmarshal([]byte("logOutput:\n - item1\n - item2"), &result)
require.Error(t, err)
assert.Contains(t, err.Error(), "must be a string")
}
func TestLogOutputValue_DefaultValue(t *testing.T) {
var result struct {
LogOutput LogOutputValue `yaml:"logOutput"`
}
// When not set in YAML
err := yaml.Unmarshal([]byte("other: value"), &result)
require.NoError(t, err)
// Should be zero
assert.True(t, result.LogOutput.IsZero())
// Default mode should be separate
assert.Equal(t, core.LogOutputSeparate, result.LogOutput.Mode())
assert.True(t, result.LogOutput.IsSeparate())
assert.False(t, result.LogOutput.IsMerged())
}
func TestLogOutputValue_Methods(t *testing.T) {
t.Run("IsMerged", func(t *testing.T) {
merged := LogOutputValue{mode: core.LogOutputMerged, set: true}
assert.True(t, merged.IsMerged())
assert.False(t, merged.IsSeparate())
separate := LogOutputValue{mode: core.LogOutputSeparate, set: true}
assert.False(t, separate.IsMerged())
assert.True(t, separate.IsSeparate())
})
t.Run("String", func(t *testing.T) {
merged := LogOutputValue{mode: core.LogOutputMerged, set: true}
assert.Equal(t, "merged", merged.String())
separate := LogOutputValue{mode: core.LogOutputSeparate, set: true}
assert.Equal(t, "separate", separate.String())
unset := LogOutputValue{}
assert.Equal(t, "separate", unset.String()) // default
})
}
func TestLogOutputMode_Constants(t *testing.T) {
// Ensure constants have expected values
assert.Equal(t, core.LogOutputMode("separate"), core.LogOutputSeparate)
assert.Equal(t, core.LogOutputMode("merged"), core.LogOutputMerged)
}
func TestLogOutputValue_UnmarshalYAML_NilValue(t *testing.T) {
var result struct {
LogOutput LogOutputValue `yaml:"logOutput"`
}
// Explicit null value in YAML
err := yaml.Unmarshal([]byte("logOutput: null"), &result)
require.NoError(t, err)
assert.True(t, result.LogOutput.IsZero())
assert.Equal(t, core.LogOutputSeparate, result.LogOutput.Mode())
}
func TestLogOutputValue_UnmarshalYAML_MapValue(t *testing.T) {
var result struct {
LogOutput LogOutputValue `yaml:"logOutput"`
}
// Map value should fail with "must be a string" error
err := yaml.Unmarshal([]byte("logOutput:\n key: value"), &result)
require.Error(t, err)
assert.Contains(t, err.Error(), "must be a string")
}

View File

@ -0,0 +1,68 @@
package types
import (
"fmt"
"github.com/goccy/go-yaml"
)
// PortValue represents a port number that can be specified as either
// a string or an integer.
//
// YAML examples:
//
// port: 22
// port: "22"
type PortValue struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
value string // Normalized string value
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (p *PortValue) UnmarshalYAML(data []byte) error {
p.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("port unmarshal error: %w", err)
}
p.raw = raw
switch v := raw.(type) {
case string:
p.value = v
return nil
case int:
p.value = fmt.Sprintf("%d", v)
return nil
case float64:
if v != float64(int(v)) {
return fmt.Errorf("port must be an integer, got %v", v)
}
p.value = fmt.Sprintf("%d", int(v))
return nil
case uint64:
p.value = fmt.Sprintf("%d", v)
return nil
case nil:
p.isSet = false
return nil
default:
return fmt.Errorf("port must be string or number, got %T", v)
}
}
// IsZero returns true if the port was not set in YAML.
func (p PortValue) IsZero() bool { return !p.isSet }
// Value returns the original raw value for error reporting.
func (p PortValue) Value() any { return p.raw }
// String returns the port as a string.
func (p PortValue) String() string { return p.value }
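
A minimal sketch of the normalization (same import-path caveat as above): an integer port and a quoted port both come out of String() as the same text, so downstream SSH or SMTP config can treat them uniformly.

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	for _, src := range []string{"port: 22", `port: "22"`} {
		var cfg struct {
			Port types.PortValue `yaml:"port"`
		}
		if err := yaml.Unmarshal([]byte(src), &cfg); err != nil {
			panic(err)
		}
		fmt.Println(cfg.Port.String()) // 22 in both cases
	}
}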

View File

@ -0,0 +1,231 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPortValue_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantString string
checkNotZero bool
}{
{
name: "Integer",
input: "22",
wantString: "22",
checkNotZero: true,
},
{
name: "String",
input: `"8080"`,
wantString: "8080",
},
{
name: "LargePortNumber",
input: "65535",
wantString: "65535",
},
{
name: "InvalidTypeArray",
input: "[22, 80]",
wantErr: true,
errContains: "must be string or number",
},
{
name: "InvalidTypeMap",
input: "{port: 22}",
wantErr: true,
errContains: "must be string or number",
},
{
name: "FloatWithDecimal",
input: "22.5",
wantErr: true,
errContains: "port must be an integer",
},
{
name: "FloatWholeNumber",
input: "22.0",
wantString: "22",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var p types.PortValue
err := yaml.Unmarshal([]byte(tt.input), &p)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantString, p.String())
if tt.checkNotZero {
assert.False(t, p.IsZero())
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var p types.PortValue
assert.True(t, p.IsZero())
assert.Equal(t, "", p.String())
})
}
func TestPortValue_InStruct(t *testing.T) {
t.Parallel()
type SSHConfig struct {
Host string `yaml:"host"`
Port types.PortValue `yaml:"port"`
User string `yaml:"user"`
}
sshTests := []struct {
name string
input string
wantHost string
wantPort string
wantUser string
wantIsZero bool
}{
{
name: "PortAsInteger",
input: `
host: example.com
port: 22
user: admin
`,
wantHost: "example.com",
wantPort: "22",
wantUser: "admin",
},
{
name: "PortAsString",
input: `
host: example.com
port: "2222"
user: admin
`,
wantPort: "2222",
},
{
name: "PortNotSet",
input: `
host: example.com
user: admin
`,
wantIsZero: true,
},
}
for _, tt := range sshTests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg SSHConfig
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantHost != "" {
assert.Equal(t, tt.wantHost, cfg.Host)
}
if tt.wantPort != "" {
assert.Equal(t, tt.wantPort, cfg.Port.String())
}
if tt.wantUser != "" {
assert.Equal(t, tt.wantUser, cfg.User)
}
if tt.wantIsZero {
assert.True(t, cfg.Port.IsZero())
}
})
}
t.Run("SMTPPort", func(t *testing.T) {
t.Parallel()
type SMTPConfig struct {
Host string `yaml:"host"`
Port types.PortValue `yaml:"port"`
Username string `yaml:"username"`
}
var cfg SMTPConfig
err := yaml.Unmarshal([]byte(`
host: smtp.example.com
port: 587
username: user
`), &cfg)
require.NoError(t, err)
assert.Equal(t, "587", cfg.Port.String())
})
}
func TestPortValue_AdditionalCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantValueNotNil bool
wantValue any
wantIsZero bool
wantString string
}{
{
name: "ValueReturnsRawInt",
input: "22",
wantValueNotNil: true,
},
{
name: "ValueReturnsRawString",
input: `"22"`,
wantValue: "22",
},
{
name: "NullValueIsZero",
input: "null",
wantIsZero: true,
},
{
name: "LargeInteger",
input: "99999",
wantString: "99999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var p types.PortValue
err := yaml.Unmarshal([]byte(tt.input), &p)
require.NoError(t, err)
if tt.wantValueNotNil {
assert.NotNil(t, p.Value())
}
if tt.wantValue != nil {
assert.Equal(t, tt.wantValue, p.Value())
}
if tt.wantIsZero {
assert.True(t, p.IsZero())
}
if tt.wantString != "" {
assert.Equal(t, tt.wantString, p.String())
}
})
}
}

View File

@ -0,0 +1,137 @@
package types
import (
"fmt"
"github.com/goccy/go-yaml"
)
// ScheduleValue represents a schedule configuration that can be specified as:
// - A single cron expression string
// - An array of cron expressions
// - A map with start/stop/restart keys
//
// YAML examples:
//
// schedule: "0 * * * *"
// schedule: ["0 * * * *", "30 * * * *"]
// schedule:
// start: "0 8 * * *"
// stop: "0 18 * * *"
// restart: "0 12 * * *"
type ScheduleValue struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
starts []string // Start schedules (or simple schedule expressions)
stops []string // Stop schedules
restarts []string // Restart schedules
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *ScheduleValue) UnmarshalYAML(data []byte) error {
s.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("schedule unmarshal error: %w", err)
}
s.raw = raw
switch v := raw.(type) {
case string:
// Single cron expression
if v != "" {
s.starts = []string{v}
}
return nil
case []any:
// Array of cron expressions
for i, item := range v {
str, ok := item.(string)
if !ok {
return fmt.Errorf("schedule[%d]: expected string, got %T", i, item)
}
s.starts = append(s.starts, str)
}
return nil
case []string:
// Array of strings (from Go types)
s.starts = v
return nil
case map[string]any:
// Map with start/stop/restart keys
return s.parseScheduleMap(v)
case nil:
// nil is valid, just means not set
s.isSet = false
return nil
default:
return fmt.Errorf("schedule must be string, array, or map, got %T", v)
}
}
func (s *ScheduleValue) parseScheduleMap(m map[string]any) error {
for key, v := range m {
values, err := parseScheduleEntry(v)
if err != nil {
return fmt.Errorf("schedule.%s: %w", key, err)
}
switch key {
case "start":
s.starts = values
case "stop":
s.stops = values
case "restart":
s.restarts = values
default:
return fmt.Errorf("schedule: unknown key %q (expected start, stop, or restart)", key)
}
}
return nil
}
func parseScheduleEntry(v any) ([]string, error) {
switch val := v.(type) {
case string:
return []string{val}, nil
case []any:
var result []string
for i, item := range val {
str, ok := item.(string)
if !ok {
return nil, fmt.Errorf("[%d]: expected string, got %T", i, item)
}
result = append(result, str)
}
return result, nil
default:
return nil, fmt.Errorf("expected string or array, got %T", v)
}
}
// IsZero returns true if the schedule was not set in YAML.
func (s ScheduleValue) IsZero() bool { return !s.isSet }
// Value returns the original raw value for error reporting.
func (s ScheduleValue) Value() any { return s.raw }
// Starts returns the start/simple schedules.
func (s ScheduleValue) Starts() []string { return s.starts }
// Stops returns the stop schedules.
func (s ScheduleValue) Stops() []string { return s.stops }
// Restarts returns the restart schedules.
func (s ScheduleValue) Restarts() []string { return s.restarts }
// HasStopSchedule returns true if stop schedules are configured.
func (s ScheduleValue) HasStopSchedule() bool { return len(s.stops) > 0 }
// HasRestartSchedule returns true if restart schedules are configured.
func (s ScheduleValue) HasRestartSchedule() bool { return len(s.restarts) > 0 }
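
A sketch of the map form (same module/import-path assumptions as the earlier snippets): start/stop/restart each accept a string or an array, and the accessors expose them as flat slices.

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	src := `
start:
  - "0 8 * * *"
  - "0 12 * * *"
stop: "0 18 * * *"
`
	var s types.ScheduleValue
	if err := yaml.Unmarshal([]byte(src), &s); err != nil {
		panic(err)
	}
	fmt.Println(s.Starts())          // [0 8 * * * 0 12 * * *]
	fmt.Println(s.Stops())           // [0 18 * * *]
	fmt.Println(s.HasStopSchedule()) // true
}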

View File

@ -0,0 +1,305 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestScheduleValue_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantStarts []string
wantStops []string
wantRestarts []string
wantHasStop bool
wantHasRestart bool
checkHasStop bool
checkHasRestart bool
}{
{
name: "SingleCronExpression",
input: `"0 * * * *"`,
wantStarts: []string{"0 * * * *"},
checkHasStop: true,
wantHasStop: false,
checkHasRestart: true,
wantHasRestart: false,
},
{
name: "ArrayOfCronExpressions",
input: `["0 * * * *", "30 * * * *"]`,
wantStarts: []string{"0 * * * *", "30 * * * *"},
},
{
name: "MultilineArray",
input: `
- "0 8 * * *"
- "0 12 * * *"
- "0 18 * * *"
`,
wantStarts: []string{"0 8 * * *", "0 12 * * *", "0 18 * * *"},
},
{
name: "MapWithStartOnly",
input: `start: "0 8 * * *"`,
wantStarts: []string{"0 8 * * *"},
checkHasStop: true,
wantHasStop: false,
},
{
name: "MapWithStartAndStop",
input: `
start: "0 8 * * *"
stop: "0 18 * * *"
`,
wantStarts: []string{"0 8 * * *"},
wantStops: []string{"0 18 * * *"},
checkHasStop: true,
wantHasStop: true,
},
{
name: "MapWithAllKeys",
input: `
start: "0 8 * * *"
stop: "0 18 * * *"
restart: "0 0 * * *"
`,
wantStarts: []string{"0 8 * * *"},
wantStops: []string{"0 18 * * *"},
wantRestarts: []string{"0 0 * * *"},
checkHasRestart: true,
wantHasRestart: true,
},
{
name: "MapWithArrayValues",
input: `
start:
- "0 8 * * *"
- "0 12 * * *"
stop: "0 18 * * *"
`,
wantStarts: []string{"0 8 * * *", "0 12 * * *"},
wantStops: []string{"0 18 * * *"},
},
{
name: "InvalidMapKey",
input: `invalid: "0 * * * *"`,
wantErr: true,
errContains: "unknown key",
},
{
name: "InvalidArrayElementType",
input: `
start:
- 123
- 456
`,
wantErr: true,
errContains: "expected string",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.ScheduleValue
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.wantStarts != nil {
assert.Equal(t, tt.wantStarts, s.Starts())
}
if tt.wantStops != nil {
assert.Equal(t, tt.wantStops, s.Stops())
}
if tt.wantRestarts != nil {
assert.Equal(t, tt.wantRestarts, s.Restarts())
}
if tt.checkHasStop {
assert.Equal(t, tt.wantHasStop, s.HasStopSchedule())
}
if tt.checkHasRestart {
assert.Equal(t, tt.wantHasRestart, s.HasRestartSchedule())
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var s types.ScheduleValue
assert.True(t, s.IsZero())
assert.Nil(t, s.Starts())
})
}
func TestScheduleValue_InStruct(t *testing.T) {
t.Parallel()
type DAGConfig struct {
Name string `yaml:"name"`
Schedule types.ScheduleValue `yaml:"schedule"`
}
tests := []struct {
name string
input string
wantName string
wantStarts []string
wantStops []string
wantIsZero bool
}{
{
name: "SimpleSchedule",
input: `
name: my-dag
schedule: "0 * * * *"
`,
wantName: "my-dag",
wantStarts: []string{"0 * * * *"},
},
{
name: "ComplexSchedule",
input: `
name: my-dag
schedule:
start: "0 8 * * 1-5"
stop: "0 18 * * 1-5"
`,
wantStarts: []string{"0 8 * * 1-5"},
wantStops: []string{"0 18 * * 1-5"},
},
{
name: "NoSchedule",
input: "name: my-dag",
wantIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg DAGConfig
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantName != "" {
assert.Equal(t, tt.wantName, cfg.Name)
}
if tt.wantStarts != nil {
assert.Equal(t, tt.wantStarts, cfg.Schedule.Starts())
}
if tt.wantStops != nil {
assert.Equal(t, tt.wantStops, cfg.Schedule.Stops())
}
if tt.wantIsZero {
assert.True(t, cfg.Schedule.IsZero())
}
})
}
}
func TestScheduleValue_AdditionalCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantValue any
checkIsZero bool
wantStarts []string
}{
{
name: "ValueReturnsRawString",
input: `"0 * * * *"`,
wantValue: "0 * * * *",
},
{
name: "NullValue",
input: "null",
checkIsZero: true,
},
{
name: "EmptyString",
input: `""`,
wantStarts: nil,
},
{
name: "InvalidTypeNumber",
input: "123",
wantErr: true,
errContains: "must be string, array, or map",
},
{
name: "InvalidScheduleEntryTypeInMap",
input: "start: 123",
wantErr: true,
errContains: "expected string or array",
},
{
name: "InvalidTypeInStartArray",
input: `["0 * * * *", 123]`,
wantErr: true,
errContains: "expected string",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.ScheduleValue
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.wantValue != nil {
assert.Equal(t, tt.wantValue, s.Value())
}
if tt.checkIsZero {
assert.True(t, s.IsZero())
}
if tt.wantStarts != nil {
assert.Equal(t, tt.wantStarts, s.Starts())
}
})
}
t.Run("ValueReturnsRawArray", func(t *testing.T) {
t.Parallel()
var s types.ScheduleValue
err := yaml.Unmarshal([]byte(`["0 * * * *"]`), &s)
require.NoError(t, err)
val, ok := s.Value().([]any)
require.True(t, ok)
assert.Len(t, val, 1)
})
t.Run("EmptyStringNotZero", func(t *testing.T) {
t.Parallel()
var s types.ScheduleValue
err := yaml.Unmarshal([]byte(`""`), &s)
require.NoError(t, err)
assert.False(t, s.IsZero())
assert.Nil(t, s.Starts())
})
}

View File

@ -0,0 +1,99 @@
package types
import (
"fmt"
"strings"
"github.com/goccy/go-yaml"
)
// ShellValue represents a shell configuration that can be specified as either
// a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
//
// YAML examples:
//
// shell: "bash -e"
// shell: bash
// shell: ["bash", "-e"]
// shell:
// - bash
// - -e
type ShellValue struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
command string // The shell command (first element)
arguments []string // Additional arguments
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *ShellValue) UnmarshalYAML(data []byte) error {
s.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("shell unmarshal error: %w", err)
}
s.raw = raw
switch v := raw.(type) {
case string:
// String value: "bash -e" or "bash"
s.command = strings.TrimSpace(v)
return nil
case []any:
// Array value: ["bash", "-e"]
if len(v) > 0 {
if cmd, ok := v[0].(string); ok {
s.command = cmd
} else {
s.command = fmt.Sprintf("%v", v[0])
}
for i := 1; i < len(v); i++ {
if arg, ok := v[i].(string); ok {
s.arguments = append(s.arguments, arg)
} else {
s.arguments = append(s.arguments, fmt.Sprintf("%v", v[i]))
}
}
}
return nil
case []string:
// Array of strings (from Go types)
if len(v) > 0 {
s.command = v[0]
s.arguments = v[1:]
}
return nil
case nil:
s.isSet = false
return nil
default:
return fmt.Errorf("shell must be string or array, got %T", v)
}
}
// IsZero returns true if the shell value was not set in YAML.
func (s ShellValue) IsZero() bool { return !s.isSet }
// Value returns the original raw value for error reporting.
func (s ShellValue) Value() any { return s.raw }
// Command returns the shell command (first element or full string).
func (s ShellValue) Command() string { return s.command }
// Arguments returns the additional shell arguments (only set for array form).
func (s ShellValue) Arguments() []string { return s.arguments }
// IsArray returns true if the value was specified as an array.
func (s ShellValue) IsArray() bool {
switch s.raw.(type) {
case []any, []string:
return true
default:
return false
}
}
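
The two shell forms side by side in a quick sketch (import path per this diff; unmarshal errors elided for brevity): the string form keeps everything in Command(), while the array form splits command and arguments.

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	var str, arr types.ShellValue
	_ = yaml.Unmarshal([]byte(`"bash -e"`), &str)
	_ = yaml.Unmarshal([]byte(`["bash", "-e"]`), &arr)
	fmt.Println(str.Command(), str.Arguments(), str.IsArray()) // bash -e [] false
	fmt.Println(arr.Command(), arr.Arguments(), arr.IsArray()) // bash [-e] true
}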

View File

@ -0,0 +1,269 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShellValue_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantCommand string
wantArgs []string
wantIsArray bool
checkNotZero bool
checkIsZero bool
}{
{
name: "StringWithoutArgs",
input: "bash",
wantCommand: "bash",
wantArgs: nil,
wantIsArray: false,
checkNotZero: true,
},
{
name: "StringWithArgsInline",
input: `"bash -e"`,
wantCommand: "bash -e",
wantArgs: nil,
},
{
name: "ArrayFormInline",
input: `["bash", "-e", "-x"]`,
wantCommand: "bash",
wantArgs: []string{"-e", "-x"},
wantIsArray: true,
},
{
name: "MultilineArrayForm",
input: "- bash\n- -e\n- -x",
wantCommand: "bash",
wantArgs: []string{"-e", "-x"},
},
{
name: "EmptyString",
input: `""`,
wantCommand: "",
checkNotZero: true,
},
{
name: "EmptyArray",
input: "[]",
wantCommand: "",
wantArgs: nil,
},
{
name: "InvalidTypeMap",
input: "{key: value}",
wantErr: true,
errContains: "must be string or array",
},
{
name: "ShellWithEnvVariableSyntax",
input: `"${SHELL}"`,
wantCommand: "${SHELL}",
},
{
name: "NixShellExample",
input: `["nix-shell", "-p", "python3"]`,
wantCommand: "nix-shell",
wantArgs: []string{"-p", "python3"},
},
{
name: "NullValue",
input: "null",
checkIsZero: true,
},
{
name: "ArrayWithNonStringItems",
input: "[123, true]",
wantCommand: "123",
wantArgs: []string{"true"},
},
{
name: "SingleElementArray",
input: `["bash"]`,
wantCommand: "bash",
wantArgs: nil,
wantIsArray: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.ShellValue
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.checkIsZero {
assert.True(t, s.IsZero())
return
}
if tt.checkNotZero {
assert.False(t, s.IsZero())
}
assert.Equal(t, tt.wantCommand, s.Command())
if tt.wantArgs != nil {
assert.Equal(t, tt.wantArgs, s.Arguments())
} else {
assert.Empty(t, s.Arguments())
}
if tt.wantIsArray {
assert.True(t, s.IsArray())
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var s types.ShellValue
assert.True(t, s.IsZero())
})
}
func TestShellValue_InStruct(t *testing.T) {
t.Parallel()
type Config struct {
Shell types.ShellValue `yaml:"shell"`
Name string `yaml:"name"`
}
tests := []struct {
name string
input string
wantName string
wantCommand string
wantIsZero bool
checkNotZero bool
}{
{
name: "ShellSet",
input: `
name: test
shell: bash
`,
wantName: "test",
wantCommand: "bash",
checkNotZero: true,
},
{
name: "ShellNotSet",
input: "name: test",
wantIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg Config
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantName != "" {
assert.Equal(t, tt.wantName, cfg.Name)
}
if tt.wantCommand != "" {
assert.Equal(t, tt.wantCommand, cfg.Shell.Command())
}
if tt.wantIsZero {
assert.True(t, cfg.Shell.IsZero())
}
if tt.checkNotZero {
assert.False(t, cfg.Shell.IsZero())
}
})
}
}
func TestShellValue_AdditionalCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantValue any
checkIsArray bool
wantIsArray bool
checkIsZero bool
}{
{
name: "ValueReturnsRawString",
input: "bash",
wantValue: "bash",
},
{
name: "InvalidTypeNumber",
input: "123",
wantErr: true,
errContains: "must be string or array",
},
{
name: "IsArrayReturnsFalseForString",
input: `"bash -e"`,
checkIsArray: true,
wantIsArray: false,
},
{
name: "IsArrayReturnsFalseForNil",
input: "null",
checkIsArray: true,
wantIsArray: false,
checkIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.ShellValue
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.wantValue != nil {
assert.Equal(t, tt.wantValue, s.Value())
}
if tt.checkIsArray {
assert.Equal(t, tt.wantIsArray, s.IsArray())
}
if tt.checkIsZero {
assert.True(t, s.IsZero())
}
})
}
t.Run("ValueReturnsRawArray", func(t *testing.T) {
t.Parallel()
var s types.ShellValue
err := yaml.Unmarshal([]byte(`["bash", "-e"]`), &s)
require.NoError(t, err)
val, ok := s.Value().([]any)
require.True(t, ok)
assert.Len(t, val, 2)
})
}

View File

@ -0,0 +1,90 @@
package types
import (
"fmt"
"github.com/goccy/go-yaml"
)
// StringOrArray represents a value that can be specified as either a single
// string or an array of strings.
//
// YAML examples:
//
// depends: "step1"
// depends: ["step1", "step2"]
// dotenv: ".env"
// dotenv: [".env", ".env.local"]
type StringOrArray struct {
raw any // Original value for error reporting
isSet bool // Whether the field was set in YAML
values []string // Parsed values
}
// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *StringOrArray) UnmarshalYAML(data []byte) error {
s.isSet = true
var raw any
if err := yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("unmarshal error: %w", err)
}
s.raw = raw
switch v := raw.(type) {
case string:
// Single string value - preserve empty strings for the validation layer to handle
s.values = []string{v}
return nil
case []any:
// Array of values - convert non-strings to strings for compatibility
for _, item := range v {
if str, ok := item.(string); ok {
s.values = append(s.values, str)
} else {
// Stringify non-string items (e.g., numeric tags)
s.values = append(s.values, fmt.Sprintf("%v", item))
}
}
return nil
case []string:
// Array of strings (from Go types)
s.values = v
return nil
case nil:
s.isSet = false
return nil
default:
return fmt.Errorf("must be string or array, got %T", v)
}
}
// IsZero returns true if the value was not set in YAML.
func (s StringOrArray) IsZero() bool { return !s.isSet }
// Value returns the original raw value for error reporting.
func (s StringOrArray) Value() any { return s.raw }
// Values returns the parsed string values.
func (s StringOrArray) Values() []string { return s.values }
// IsEmpty returns true if set but contains no values (empty array).
func (s StringOrArray) IsEmpty() bool { return s.isSet && len(s.values) == 0 }
// MailToValue is an alias for StringOrArray used for email recipients.
// YAML examples:
//
// to: user@example.com
// to: ["user1@example.com", "user2@example.com"]
type MailToValue = StringOrArray
// TagsValue is an alias for StringOrArray used for tags.
// YAML examples:
//
// tags: production
// tags: ["production", "critical"]
type TagsValue = StringOrArray
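
A short sketch of both spellings (same import-path assumption): a scalar and a one-element array are equivalent after parsing, which is what lets depends/dotenv/tags accept either form.

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	var scalar, list types.StringOrArray
	// Errors elided for brevity; both inputs are valid YAML.
	_ = yaml.Unmarshal([]byte("step1"), &scalar)
	_ = yaml.Unmarshal([]byte(`["step1", "step2"]`), &list)
	fmt.Println(scalar.Values())                  // [step1]
	fmt.Println(list.Values())                    // [step1 step2]
	fmt.Println(scalar.IsEmpty(), list.IsEmpty()) // false false
}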

View File

@ -0,0 +1,309 @@
package types_test
import (
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStringOrArray_UnmarshalYAML(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantValues []string
checkIsEmpty bool
wantIsEmpty bool
checkNotZero bool
}{
{
name: "SingleString",
input: "step1",
wantValues: []string{"step1"},
checkNotZero: true,
},
{
name: "ArrayOfStringsInline",
input: `["step1", "step2", "step3"]`,
wantValues: []string{"step1", "step2", "step3"},
},
{
name: "MultilineArray",
input: "- step1\n- step2",
wantValues: []string{"step1", "step2"},
},
{
name: "EmptyString",
input: `""`,
wantValues: []string{""},
checkIsEmpty: true,
wantIsEmpty: false,
checkNotZero: true,
},
{
name: "EmptyArray",
input: "[]",
wantValues: nil,
checkIsEmpty: true,
wantIsEmpty: true,
},
{
name: "InvalidTypeMap",
input: "{key: value}",
wantErr: true,
errContains: "must be string or array",
},
{
name: "QuotedStringWithSpaces",
input: `"step with spaces"`,
wantValues: []string{"step with spaces"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.StringOrArray
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.wantValues == nil {
assert.Empty(t, s.Values())
} else {
assert.Equal(t, tt.wantValues, s.Values())
}
if tt.checkIsEmpty {
assert.Equal(t, tt.wantIsEmpty, s.IsEmpty())
}
if tt.checkNotZero {
assert.False(t, s.IsZero())
}
})
}
t.Run("ZeroValue", func(t *testing.T) {
t.Parallel()
var s types.StringOrArray
assert.True(t, s.IsZero())
assert.Nil(t, s.Values())
})
}
func TestStringOrArray_InStruct(t *testing.T) {
t.Parallel()
type StepConfig struct {
Name string `yaml:"name"`
Depends types.StringOrArray `yaml:"depends"`
}
tests := []struct {
name string
input string
wantValues []string
wantIsZero bool
checkIsEmpty bool
wantIsEmpty bool
}{
{
name: "DependsAsString",
input: `
name: step2
depends: step1
`,
wantValues: []string{"step1"},
},
{
name: "DependsAsArray",
input: `
name: step3
depends:
- step1
- step2
`,
wantValues: []string{"step1", "step2"},
},
{
name: "DependsNotSet",
input: "name: step1",
wantIsZero: true,
},
{
name: "DependsEmptyArray",
input: `
name: step2
depends: []
`,
checkIsEmpty: true,
wantIsEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var cfg StepConfig
err := yaml.Unmarshal([]byte(tt.input), &cfg)
require.NoError(t, err)
if tt.wantValues != nil {
assert.Equal(t, tt.wantValues, cfg.Depends.Values())
}
if tt.wantIsZero {
assert.True(t, cfg.Depends.IsZero())
}
if tt.checkIsEmpty {
assert.False(t, cfg.Depends.IsZero())
assert.Equal(t, tt.wantIsEmpty, cfg.Depends.IsEmpty())
}
})
}
}
func TestMailToValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantValues []string
}{
{
name: "SingleEmail",
input: "user@example.com",
wantValues: []string{"user@example.com"},
},
{
name: "MultipleEmails",
input: `["user1@example.com", "user2@example.com"]`,
wantValues: []string{"user1@example.com", "user2@example.com"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var m types.MailToValue
err := yaml.Unmarshal([]byte(tt.input), &m)
require.NoError(t, err)
assert.Equal(t, tt.wantValues, m.Values())
})
}
}
func TestTagsValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantValues []string
}{
{
name: "SingleTag",
input: "production",
wantValues: []string{"production"},
},
{
name: "MultipleTags",
input: `["production", "critical", "monitored"]`,
wantValues: []string{"production", "critical", "monitored"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var tags types.TagsValue
err := yaml.Unmarshal([]byte(tt.input), &tags)
require.NoError(t, err)
assert.Equal(t, tt.wantValues, tags.Values())
})
}
}
func TestStringOrArray_AdditionalCoverage(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
errContains string
wantValues []string
checkIsZero bool
}{
{
name: "ArrayWithNumericValues",
input: "[1, 2, 3]",
wantValues: []string{"1", "2", "3"},
},
{
name: "ArrayWithMixedTypes",
input: `["step1", 123, true]`,
wantValues: []string{"step1", "123", "true"},
},
{
name: "InvalidTypeNumber",
input: "123",
wantErr: true,
errContains: "must be string or array",
},
{
name: "NullValue",
input: "null",
checkIsZero: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var s types.StringOrArray
err := yaml.Unmarshal([]byte(tt.input), &s)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
require.NoError(t, err)
if tt.wantValues != nil {
assert.Equal(t, tt.wantValues, s.Values())
}
if tt.checkIsZero {
assert.True(t, s.IsZero())
}
})
}
t.Run("ValueReturnsRawString", func(t *testing.T) {
t.Parallel()
var s types.StringOrArray
err := yaml.Unmarshal([]byte("step1"), &s)
require.NoError(t, err)
assert.Equal(t, "step1", s.Value())
})
t.Run("ValueReturnsRawArray", func(t *testing.T) {
t.Parallel()
var s types.StringOrArray
err := yaml.Unmarshal([]byte(`["step1", "step2"]`), &s)
require.NoError(t, err)
val, ok := s.Value().([]any)
require.True(t, ok)
assert.Len(t, val, 2)
})
}

View File

@ -7,6 +7,7 @@ import (
"github.com/dagu-org/dagu/internal/common/cmdutil"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/spec/types"
)
// loadVariables loads the environment variables from the map.
@ -89,6 +90,43 @@ func loadVariables(ctx BuildContext, strVariables any) (
return vars, nil
}
// loadVariablesFromEnvValue loads environment variables from a types.EnvValue.
// This function converts the typed EnvValue entries to the expected format
// and processes them using the same logic as loadVariables.
func loadVariablesFromEnvValue(ctx BuildContext, env types.EnvValue) (
map[string]string, error,
) {
if env.IsZero() {
return nil, nil
}
vars := map[string]string{}
for _, entry := range env.Entries() {
value := entry.Value
if !ctx.opts.Has(BuildFlagNoEval) {
// Evaluate the value of the environment variable.
// This also executes command substitution.
// Pass accumulated vars so ${VAR} can reference previously defined vars
var err error
value, err = cmdutil.EvalString(ctx.ctx, value, cmdutil.WithVariables(vars))
if err != nil {
return nil, core.NewValidationError("env", entry.Value, fmt.Errorf("%w: %s", ErrInvalidEnvValue, entry.Value))
}
// Set the environment variable.
if err := os.Setenv(entry.Key, value); err != nil {
return nil, core.NewValidationError("env", entry.Key, fmt.Errorf("%w: %s", err, entry.Key))
}
}
vars[entry.Key] = value
}
return vars, nil
}
// pair represents a key-value pair.
type pair struct {
key string
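
loadVariablesFromEnvValue itself is unexported, but the ordered input it consumes is visible through EnvValue; a hedged sketch (import path per this diff) of the array form whose definition order lets later entries reference earlier ones:

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	// The array form preserves definition order, which is what lets the
	// builder expand ${BASE} in PATH_VAR when evaluation is enabled.
	var env types.EnvValue
	src := "- BASE: /opt\n- PATH_VAR: ${BASE}/bin"
	if err := yaml.Unmarshal([]byte(src), &env); err != nil {
		panic(err)
	}
	for _, e := range env.Entries() {
		fmt.Printf("%s=%s\n", e.Key, e.Value) // BASE=/opt, then PATH_VAR=${BASE}/bin
	}
}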

View File

@ -0,0 +1,420 @@
package spec
import (
"context"
"testing"
"github.com/dagu-org/dagu/internal/core/spec/types"
"github.com/goccy/go-yaml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseKeyValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input map[string]any
expected []pair
}{
{
name: "EmptyMap",
input: map[string]any{},
expected: nil,
},
{
name: "SingleStringValue",
input: map[string]any{"FOO": "bar"},
expected: []pair{{key: "FOO", val: "bar"}},
},
{
name: "IntegerValue",
input: map[string]any{"COUNT": 42},
expected: []pair{{key: "COUNT", val: "42"}},
},
{
name: "BooleanValue",
input: map[string]any{"DEBUG": true},
expected: []pair{{key: "DEBUG", val: "true"}},
},
{
name: "FloatValue",
input: map[string]any{"RATIO": 3.14},
expected: []pair{{key: "RATIO", val: "3.14"}},
},
{
name: "MultipleValues",
input: map[string]any{"A": "1", "B": "2"},
expected: []pair{
{key: "A", val: "1"},
{key: "B", val: "2"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var pairs []pair
err := parseKeyValue(tt.input, &pairs)
require.NoError(t, err)
if tt.expected == nil {
assert.Empty(t, pairs)
return
}
// Since map iteration order is not guaranteed, check by content
assert.Len(t, pairs, len(tt.expected))
for _, exp := range tt.expected {
found := false
for _, p := range pairs {
if p.key == exp.key && p.val == exp.val {
found = true
break
}
}
assert.True(t, found, "expected pair %v not found", exp)
}
})
}
}
func TestLoadVariables(t *testing.T) {
t.Parallel()
t.Run("MapInput", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input map[string]any
expected map[string]string
}{
{
name: "SingleVariable",
input: map[string]any{"FOO": "bar"},
expected: map[string]string{"FOO": "bar"},
},
{
name: "MultipleVariables",
input: map[string]any{"A": "1", "B": "2"},
expected: map[string]string{"A": "1", "B": "2"},
},
{
name: "IntegerValue",
input: map[string]any{"PORT": 8080},
expected: map[string]string{"PORT": "8080"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
result, err := loadVariables(ctx, tt.input)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
})
t.Run("ArrayInput", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []any
expected map[string]string
}{
{
name: "ArrayOfMaps",
input: []any{
map[string]any{"FOO": "bar"},
map[string]any{"BAZ": "qux"},
},
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
},
{
name: "ArrayOfStrings",
input: []any{
"FOO=bar",
"BAZ=qux",
},
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
},
{
name: "MixedArrayOfMapsAndStrings",
input: []any{
map[string]any{"FOO": "bar"},
"BAZ=qux",
},
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
result, err := loadVariables(ctx, tt.input)
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
})
t.Run("ErrorCases", func(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input any
errContains string
}{
{
name: "InvalidStringFormat",
input: []any{"INVALID_NO_EQUALS"},
errContains: "env config should be map of strings or array of key=value",
},
{
name: "InvalidTypeInArray",
input: []any{123},
errContains: "env config should be map of strings or array of key=value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
_, err := loadVariables(ctx, tt.input)
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errContains)
})
}
})
t.Run("WithEvaluation", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
input := map[string]any{"GREETING": "hello"}
result, err := loadVariables(ctx, input)
require.NoError(t, err)
assert.Equal(t, "hello", result["GREETING"])
})
t.Run("NoEvalFlag", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
// With NoEval, command substitution should not be executed
input := map[string]any{"CMD": "$(echo hello)"}
result, err := loadVariables(ctx, input)
require.NoError(t, err)
assert.Equal(t, "$(echo hello)", result["CMD"])
})
t.Run("VariableReference", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
// Test that later variables can reference earlier ones
input := []any{
map[string]any{"BASE": "/opt"},
map[string]any{"PATH_VAR": "${BASE}/bin"},
}
result, err := loadVariables(ctx, input)
require.NoError(t, err)
assert.Equal(t, "/opt", result["BASE"])
assert.Equal(t, "/opt/bin", result["PATH_VAR"])
})
t.Run("EmptyValue", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
input := map[string]any{"EMPTY": ""}
result, err := loadVariables(ctx, input)
require.NoError(t, err)
assert.Equal(t, "", result["EMPTY"])
})
t.Run("ValueWithEqualsSign", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
input := []any{"KEY=value=with=equals"}
result, err := loadVariables(ctx, input)
require.NoError(t, err)
assert.Equal(t, "value=with=equals", result["KEY"])
})
}
// Helper to create EnvValue from YAML string
func envValueFromYAML(t *testing.T, yamlStr string) types.EnvValue {
t.Helper()
var env types.EnvValue
err := yaml.Unmarshal([]byte(yamlStr), &env)
require.NoError(t, err)
return env
}
func TestLoadVariablesFromEnvValue(t *testing.T) {
t.Parallel()
t.Run("EmptyEnvValue", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
var env types.EnvValue
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Nil(t, result)
})
t.Run("MapFormat", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
env := envValueFromYAML(t, `
FOO: bar
BAZ: qux
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "bar", result["FOO"])
assert.Equal(t, "qux", result["BAZ"])
})
t.Run("ArrayFormat", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
env := envValueFromYAML(t, `
- FOO: bar
- BAZ: qux
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "bar", result["FOO"])
assert.Equal(t, "qux", result["BAZ"])
})
t.Run("WithEvaluation", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
env := envValueFromYAML(t, `
GREETING: hello
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "hello", result["GREETING"])
})
t.Run("NoEvalFlag", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
env := envValueFromYAML(t, `
CMD: "$(echo hello)"
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "$(echo hello)", result["CMD"])
})
t.Run("VariableReference", func(t *testing.T) {
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{},
}
env := envValueFromYAML(t, `
- BASE: /opt
- PATH_VAR: "${BASE}/bin"
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "/opt", result["BASE"])
assert.Equal(t, "/opt/bin", result["PATH_VAR"])
})
t.Run("IntegerValue", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
env := envValueFromYAML(t, `
PORT: 8080
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "8080", result["PORT"])
})
t.Run("BooleanValue", func(t *testing.T) {
t.Parallel()
ctx := BuildContext{
ctx: context.Background(),
opts: BuildOpts{Flags: BuildFlagNoEval},
}
env := envValueFromYAML(t, `
DEBUG: true
`)
result, err := loadVariablesFromEnvValue(ctx, env)
require.NoError(t, err)
assert.Equal(t, "true", result["DEBUG"])
})
}

View File

@ -21,30 +21,46 @@ type Step struct {
Shell string `json:"shell,omitempty"`
// ShellPackages is the list of packages to install. This is used only when the shell is `nix-shell`.
ShellPackages []string `json:"shellPackages,omitempty"`
// ShellArgs is the list of arguments for the shell program.
ShellArgs []string `json:"shellArgs,omitempty"`
// Dir is the working directory for the step.
Dir string `json:"dir,omitempty"`
// ExecutorConfig contains the configuration for the executor.
ExecutorConfig ExecutorConfig `json:"executorConfig,omitzero"`
// CmdWithArgs is the command with arguments for display purposes.
// Deprecated: Use Commands[0].CmdWithArgs instead. Kept for JSON backward compatibility.
CmdWithArgs string `json:"cmdWithArgs,omitempty"`
// CmdArgsSys is the command with arguments for the system.
// Deprecated: Kept for JSON backward compatibility.
CmdArgsSys string `json:"cmdArgsSys,omitempty"`
// Command specifies only the command without arguments.
// Deprecated: Use Commands field instead. Kept for JSON backward compatibility.
Command string `json:"command,omitempty"`
// ShellCmdArgs is the shell command with arguments.
ShellCmdArgs string `json:"shellCmdArgs,omitempty"`
// Script is the script to be executed.
Script string `json:"script,omitempty"`
// Args contains the arguments for the command.
// Deprecated: Use Commands field instead. Kept for JSON backward compatibility.
Args []string `json:"args,omitempty"`
// Commands is the source of truth for commands to execute.
// Each entry represents a command to be executed sequentially.
// For single commands, this will contain exactly one entry.
Commands []CommandEntry `json:"commands,omitempty"`
// Stdout is the file to store the standard output.
Stdout string `json:"stdout,omitempty"`
// Stderr is the file to store the standard error.
Stderr string `json:"stderr,omitempty"`
// LogOutput specifies how stdout and stderr are handled in log files for this step.
// Overrides the DAG-level LogOutput setting. Empty string means inherit from DAG.
LogOutput LogOutputMode `json:"logOutput,omitempty"`
// Output is the variable name to store the output.
Output string `json:"output,omitempty"`
// OutputKey is the custom key for the output in outputs.json.
// If empty, the Output name is converted from UPPER_CASE to camelCase.
OutputKey string `json:"outputKey,omitempty"`
// OutputOmit excludes this output from outputs.json when true.
OutputOmit bool `json:"outputOmit,omitempty"`
// Depends contains the list of step names to depend on.
Depends []string `json:"depends,omitempty"`
// ExplicitlyNoDeps indicates the depends field was explicitly set to empty
@ -74,6 +90,10 @@ type Step struct {
// Timeout specifies the maximum execution time for the step.
// If set, this timeout takes precedence over the DAG-level timeout for this step.
Timeout time.Duration `json:"timeout,omitempty"`
// Container specifies the container configuration for this step.
// If set, the step runs in its own container instead of the DAG-level container.
// This uses the same configuration format as the DAG-level container field.
Container *Container `json:"container,omitempty"`
}
// String returns a formatted string representation of the step
@ -103,6 +123,68 @@ type SubDAG struct {
Params string `json:"params,omitempty"`
}
// CommandEntry represents a single command in a multi-command step.
// Each entry contains a parsed command with its arguments.
type CommandEntry struct {
// Command is the executable name or path.
Command string `json:"command"`
// Args contains the arguments for the command.
Args []string `json:"args,omitempty"`
// CmdWithArgs is the original command string for display purposes.
CmdWithArgs string `json:"cmdWithArgs,omitempty"`
}
// String returns a display string for the command entry.
func (c CommandEntry) String() string {
if c.CmdWithArgs != "" {
return c.CmdWithArgs
}
if c.Command == "" {
return ""
}
if len(c.Args) == 0 {
return c.Command
}
return c.Command + " " + strings.Join(c.Args, " ")
}
// HasMultipleCommands returns true if the step has multiple commands to execute.
func (s *Step) HasMultipleCommands() bool {
return len(s.Commands) > 1
}
// UnmarshalJSON implements json.Unmarshaler for backward compatibility.
// It handles old JSON format where command/args fields were used instead of commands.
func (s *Step) UnmarshalJSON(data []byte) error {
// Use type alias to avoid infinite recursion
type Alias Step
aux := &struct {
*Alias
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(data, aux); err != nil {
return err
}
// If Commands is already populated, we're done (new format)
if len(s.Commands) > 0 {
return nil
}
// Migrate legacy fields to Commands only when legacy command data exists.
if s.Command != "" || len(s.Args) > 0 || s.CmdWithArgs != "" {
s.Commands = []CommandEntry{{
Command: s.Command,
Args: s.Args,
CmdWithArgs: s.CmdWithArgs,
}}
}
return nil
}
// ExecutorConfig contains the configuration for the executor.
type ExecutorConfig struct {
// Type represents one of the registered executors.
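
The backward-compatibility path in UnmarshalJSON can be exercised directly; a minimal sketch, assuming Step and CommandEntry remain exported from github.com/dagu-org/dagu/internal/core as this diff shows:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/dagu-org/dagu/internal/core"
)

func main() {
	// Legacy payload: single command/args fields, no "commands" list.
	legacy := []byte(`{"name":"build","command":"make","args":["all"]}`)
	var s core.Step
	if err := json.Unmarshal(legacy, &s); err != nil {
		panic(err)
	}
	// UnmarshalJSON migrates the legacy fields into Commands.
	fmt.Println(len(s.Commands), s.Commands[0].String()) // 1 make all
	fmt.Println(s.HasMultipleCommands())                 // false
}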

View File

@ -276,3 +276,43 @@ func TestRepeatPolicy_MarshalUnmarshal(t *testing.T) {
assert.Equal(t, rp.RepeatMode, rp2.RepeatMode)
})
}
func TestCommandEntry_String(t *testing.T) {
tests := []struct {
name string
entry CommandEntry
expected string
}{
{
name: "Empty",
entry: CommandEntry{},
expected: "",
},
{
name: "CommandOnly",
entry: CommandEntry{Command: "echo"},
expected: "echo",
},
{
name: "CommandWithArgs",
entry: CommandEntry{Command: "echo", Args: []string{"hello", "world"}},
expected: "echo hello world",
},
{
name: "CmdWithArgsTakesPriority",
entry: CommandEntry{Command: "echo", Args: []string{"hello"}, CmdWithArgs: "echo 'hello world'"},
expected: "echo 'hello world'",
},
{
name: "CmdWithArgsOnly",
entry: CommandEntry{CmdWithArgs: "ls -la"},
expected: "ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.entry.String())
})
}
}

View File

@ -6,6 +6,28 @@ import (
"strings"
)
// DAGNameMaxLen defines the maximum allowed length for a DAG name.
const DAGNameMaxLen = 40
// dagNameRegex matches valid DAG names: alphanumeric, underscore, dash, dot.
var dagNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)
// ValidateDAGName validates a DAG name according to shared rules.
// - Empty name is allowed (caller may provide one via context or filename).
// - Non-empty name must satisfy length and allowed character constraints.
func ValidateDAGName(name string) error {
if name == "" {
return nil
}
if len(name) > DAGNameMaxLen {
return ErrNameTooLong
}
if !dagNameRegex.MatchString(name) {
return ErrNameInvalidChars
}
return nil
}
// StepValidator is a function type for validating step configurations.
type StepValidator func(step Step) error
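
A quick sketch of the name rules (package path assumed as in this diff): an empty name passes through for the caller to fill in, while length and character violations surface as the shared sentinel errors.

package main

import (
	"fmt"
	"strings"

	"github.com/dagu-org/dagu/internal/core"
)

func main() {
	fmt.Println(core.ValidateDAGName(""))                      // <nil>: deferred to the caller
	fmt.Println(core.ValidateDAGName("etl_daily-v1.2"))        // <nil>: allowed characters
	fmt.Println(core.ValidateDAGName(strings.Repeat("a", 41))) // ErrNameTooLong (limit is 40)
	fmt.Println(core.ValidateDAGName("bad name!"))             // ErrNameInvalidChars
}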

View File

@ -0,0 +1,774 @@
package core
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsValidStepID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
id string
expected bool
}{
// Valid cases - starts with letter, followed by alphanumeric/dash/underscore
{name: "single letter", id: "a", expected: true},
{name: "simple word", id: "step", expected: true},
{name: "word with number", id: "step1", expected: true},
{name: "word with dash", id: "my-step", expected: true},
{name: "word with underscore", id: "my_step", expected: true},
{name: "mixed case", id: "MyStep", expected: true},
{name: "uppercase", id: "STEP", expected: true},
{name: "complex valid id", id: "Step123-test_id", expected: true},
{name: "letters and numbers", id: "step123abc", expected: true},
{name: "uppercase with numbers", id: "STEP123", expected: true},
// Invalid cases
{name: "starts with number", id: "1step", expected: false},
{name: "starts with dash", id: "-step", expected: false},
{name: "starts with underscore", id: "_step", expected: false},
{name: "contains space", id: "step name", expected: false},
{name: "contains exclamation", id: "step!", expected: false},
{name: "contains at sign", id: "step@test", expected: false},
{name: "contains dot", id: "step.name", expected: false},
{name: "empty string", id: "", expected: false},
{name: "only numbers", id: "123", expected: false},
{name: "contains slash", id: "step/name", expected: false},
{name: "contains colon", id: "step:name", expected: false},
{name: "contains equals", id: "step=value", expected: false},
{name: "unicode characters", id: "step日本語", expected: false},
{name: "emoji", id: "step🚀", expected: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := isValidStepID(tt.id)
assert.Equal(t, tt.expected, result,
"isValidStepID(%q) = %v, want %v", tt.id, result, tt.expected)
})
}
}
func TestIsReservedWord(t *testing.T) {
t.Parallel()
// All reserved words (case insensitive)
reservedWords := []string{"env", "params", "args", "stdout", "stderr", "output", "outputs"}
t.Run("reserved words are detected", func(t *testing.T) {
t.Parallel()
for _, word := range reservedWords {
assert.True(t, isReservedWord(word),
"isReservedWord(%q) should return true", word)
}
})
t.Run("reserved words uppercase are detected", func(t *testing.T) {
t.Parallel()
for _, word := range reservedWords {
upper := strings.ToUpper(word)
assert.True(t, isReservedWord(upper),
"isReservedWord(%q) should return true", upper)
}
})
t.Run("reserved words mixed case are detected", func(t *testing.T) {
t.Parallel()
mixedCases := []string{"Env", "PARAMS", "Args", "StdOut", "StdErr", "Output", "Outputs"}
for _, word := range mixedCases {
assert.True(t, isReservedWord(word),
"isReservedWord(%q) should return true", word)
}
})
t.Run("non-reserved words are not detected", func(t *testing.T) {
t.Parallel()
nonReserved := []string{
"environment",
"parameter",
"arguments",
"step",
"run",
"execute",
"command",
"envs",
"param",
"arg",
"out",
"err",
"myenv",
"test-stdout",
}
for _, word := range nonReserved {
assert.False(t, isReservedWord(word),
"isReservedWord(%q) should return false", word)
}
})
t.Run("empty string is not reserved", func(t *testing.T) {
t.Parallel()
assert.False(t, isReservedWord(""))
})
}
func TestValidateSteps(t *testing.T) {
t.Parallel()
// Use a non-empty executor type to avoid triggering command validators
// that may be registered via init() from other packages
testExec := ExecutorConfig{Type: "test-no-validator"}
t.Run("valid DAG with steps passes validation", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ExecutorConfig: testExec},
{Name: "step2", Depends: []string{"step1"}, ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("empty DAG passes validation", func(t *testing.T) {
t.Parallel()
dag := &DAG{Steps: []Step{}}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
// Pass 1: ID validation tests
t.Run("step with valid ID passes", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "myStepId", ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("step with invalid ID format fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "1invalid"}, // starts with number
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid step ID format")
})
t.Run("step with reserved word ID fails", func(t *testing.T) {
t.Parallel()
reservedWords := []string{"env", "params", "args", "stdout", "stderr", "output", "outputs"}
for _, word := range reservedWords {
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: word},
},
}
err := ValidateSteps(dag)
require.Error(t, err, "ID %q should be rejected as reserved", word)
assert.Contains(t, err.Error(), "reserved word")
}
})
t.Run("duplicate step names fail", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "duplicate"},
{Name: "duplicate"},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrStepNameDuplicate))
})
t.Run("duplicate step IDs fail", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "sameId"},
{Name: "step2", ID: "sameId"},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "duplicate step ID")
})
// Pass 2: Name/ID conflict tests
t.Run("step ID conflicts with another step name fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "conflictName", ExecutorConfig: testExec},
{Name: "step2", ID: "conflictName", ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
// The validator detects that step name "conflictName" conflicts with another step's ID
assert.Contains(t, err.Error(), "conflicts")
})
t.Run("step name conflicts with another step ID fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "conflictId", ExecutorConfig: testExec},
{Name: "conflictId", ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
// The validator detects that step ID "conflictId" conflicts with another step's name
assert.Contains(t, err.Error(), "conflicts")
})
t.Run("same step has matching name and ID is allowed", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "sameName", ID: "sameName", ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
// Pass 3 & 4: Dependency tests
t.Run("valid dependencies pass", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ExecutorConfig: testExec},
{Name: "step2", Depends: []string{"step1"}, ExecutorConfig: testExec},
{Name: "step3", Depends: []string{"step1", "step2"}, ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("non-existent dependency fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1"},
{Name: "step2", Depends: []string{"nonexistent"}},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "non-existent step")
})
t.Run("ID reference in depends is resolved to name", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "s1", ExecutorConfig: testExec},
{Name: "step2", Depends: []string{"s1"}, ExecutorConfig: testExec}, // references ID
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
// After validation, depends should be resolved to name
assert.Contains(t, dag.Steps[1].Depends, "step1")
})
// Pass 5: Step validation tests
t.Run("step with empty name fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: ""},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
// The internal error is "step name not generated"
assert.Contains(t, err.Error(), "step name")
})
t.Run("step name too long fails", func(t *testing.T) {
t.Parallel()
longName := strings.Repeat("a", 41) // 41 chars, max is 40
dag := &DAG{
Steps: []Step{
{Name: longName},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrStepNameTooLong))
})
t.Run("step name at max length passes", func(t *testing.T) {
t.Parallel()
maxName := strings.Repeat("a", 40) // exactly 40 chars
dag := &DAG{
Steps: []Step{
{Name: maxName, ExecutorConfig: testExec},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
// Parallel config validation
t.Run("parallel config without SubDAG fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: 2,
Items: []ParallelItem{{Value: "a"}, {Value: "b"}},
},
// SubDAG is nil
},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "only supported for child-DAGs")
})
t.Run("parallel config with maxConcurrent 0 fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: 0,
Items: []ParallelItem{{Value: "a"}, {Value: "b"}},
},
SubDAG: &SubDAG{Name: "child"},
},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "maxConcurrent must be greater than 0")
})
t.Run("parallel config with negative maxConcurrent fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: -1,
Items: []ParallelItem{{Value: "a"}, {Value: "b"}},
},
SubDAG: &SubDAG{Name: "child"},
},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "maxConcurrent must be greater than 0")
})
t.Run("parallel config without items or variable fails", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: 2,
// no items, no variable
},
SubDAG: &SubDAG{Name: "child"},
},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "must have either items array or variable reference")
})
t.Run("valid parallel config with items passes", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: 2,
Items: []ParallelItem{{Value: "a"}, {Value: "b"}, {Value: "c"}},
},
SubDAG: &SubDAG{Name: "child"},
},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("valid parallel config with variable passes", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{
Name: "step1",
Parallel: &ParallelConfig{
MaxConcurrent: 2,
Variable: "ITEMS",
},
SubDAG: &SubDAG{Name: "child"},
},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
}
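// A minimal sketch of the name/ID conflict pass exercised above, assuming
// only the Step fields used in these tests (Name, ID). The production
// validator's exact messages and error wrapping may differ.
func checkNameIDConflictsSketch(steps []Step) error {
	seenIDs := make(map[string]bool, len(steps))
	for _, s := range steps {
		if s.ID == "" {
			continue
		}
		if seenIDs[s.ID] {
			return fmt.Errorf("duplicate step ID %q", s.ID)
		}
		seenIDs[s.ID] = true
	}
	for i, s := range steps {
		if s.ID == "" {
			continue
		}
		for j, other := range steps {
			// A single step may use the same string for both its name and
			// its ID, but an ID must not collide with any other step's name.
			if i != j && s.ID == other.Name {
				return fmt.Errorf("step ID %q conflicts with step name %q", s.ID, other.Name)
			}
		}
	}
	return nil
}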
func TestRegisterStepValidator(t *testing.T) {
// Note: These tests modify global state, so they should not run in parallel
// with each other. Each test should clean up after itself.
t.Run("register validator for new type", func(t *testing.T) {
// Clean up after test
defer delete(stepValidators, "test-executor")
validatorCalled := false
validator := func(_ Step) error {
validatorCalled = true
return nil
}
RegisterStepValidator("test-executor", validator)
// Create a DAG with a step using this executor type
dag := &DAG{
Steps: []Step{
{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "test-executor"},
},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
assert.True(t, validatorCalled, "validator should have been called")
})
t.Run("validator returning error propagates", func(t *testing.T) {
defer delete(stepValidators, "error-executor")
expectedErr := errors.New("validation failed")
validator := func(_ Step) error {
return expectedErr
}
RegisterStepValidator("error-executor", validator)
dag := &DAG{
Steps: []Step{
{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "error-executor"},
},
},
}
err := ValidateSteps(dag)
require.Error(t, err)
assert.Contains(t, err.Error(), "validation failed")
})
t.Run("overwrite existing validator", func(t *testing.T) {
defer delete(stepValidators, "overwrite-executor")
firstCalled := false
secondCalled := false
first := func(_ Step) error {
firstCalled = true
return nil
}
second := func(_ Step) error {
secondCalled = true
return nil
}
RegisterStepValidator("overwrite-executor", first)
RegisterStepValidator("overwrite-executor", second)
dag := &DAG{
Steps: []Step{
{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "overwrite-executor"},
},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
assert.False(t, firstCalled, "first validator should not be called")
assert.True(t, secondCalled, "second validator should be called")
})
t.Run("no validator for type does not fail", func(t *testing.T) {
dag := &DAG{
Steps: []Step{
{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "unregistered-executor"},
},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
}
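// How an executor package would hook into this registry, mirroring the
// registration pattern tested above. The "my-executor" type and its rule are
// invented for illustration; real executors register from their own init().
func init() {
	RegisterStepValidator("my-executor", func(s Step) error {
		if s.SubDAG != nil {
			return errors.New("my-executor cannot run a child DAG")
		}
		return nil
	})
}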
func TestResolveStepDependencies(t *testing.T) {
t.Parallel()
t.Run("resolves ID references to names", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "firstStep", ID: "first"},
{Name: "secondStep", ID: "second"},
{Name: "thirdStep", Depends: []string{"first", "second"}},
},
}
err := resolveStepDependencies(dag)
require.NoError(t, err)
// Dependencies should be resolved to names
assert.Contains(t, dag.Steps[2].Depends, "firstStep")
assert.Contains(t, dag.Steps[2].Depends, "secondStep")
})
t.Run("preserves name references", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1"},
{Name: "step2", Depends: []string{"step1"}}, // uses name, not ID
},
}
err := resolveStepDependencies(dag)
require.NoError(t, err)
// Name reference should remain unchanged
assert.Contains(t, dag.Steps[1].Depends, "step1")
})
t.Run("mixed ID and name references", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "s1"},
{Name: "step2"},
{Name: "step3", Depends: []string{"s1", "step2"}}, // mix of ID and name
},
}
err := resolveStepDependencies(dag)
require.NoError(t, err)
// ID should be resolved, name should remain
assert.Contains(t, dag.Steps[2].Depends, "step1") // s1 resolved to step1
assert.Contains(t, dag.Steps[2].Depends, "step2") // step2 unchanged
})
t.Run("empty DAG", func(t *testing.T) {
t.Parallel()
dag := &DAG{Steps: []Step{}}
err := resolveStepDependencies(dag)
assert.NoError(t, err)
})
t.Run("steps without dependencies", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
{Name: "step1", ID: "s1"},
{Name: "step2", ID: "s2"},
},
}
err := resolveStepDependencies(dag)
assert.NoError(t, err)
assert.Empty(t, dag.Steps[0].Depends)
assert.Empty(t, dag.Steps[1].Depends)
})
}
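// A minimal sketch of the resolution pass verified above: build an ID-to-name
// index, then rewrite each Depends entry that matches a known ID. Unknown
// references are left alone for the later existence check to reject.
func resolveDependsSketch(dag *DAG) {
	idToName := make(map[string]string, len(dag.Steps))
	for _, s := range dag.Steps {
		if s.ID != "" {
			idToName[s.ID] = s.Name
		}
	}
	for i := range dag.Steps {
		for j, dep := range dag.Steps[i].Depends {
			if name, ok := idToName[dep]; ok {
				dag.Steps[i].Depends[j] = name
			}
		}
	}
}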
func TestValidateStep(t *testing.T) {
t.Parallel()
// Use a non-empty executor type to avoid triggering command validators
testExecutorType := "test-no-validator"
t.Run("valid step passes", func(t *testing.T) {
t.Parallel()
step := Step{Name: "validStep", ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
err := validateStep(step)
assert.NoError(t, err)
})
t.Run("empty name fails", func(t *testing.T) {
t.Parallel()
step := Step{Name: "", ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
err := validateStep(step)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrStepNameRequired))
})
t.Run("name too long fails", func(t *testing.T) {
t.Parallel()
step := Step{Name: strings.Repeat("x", 41), ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
err := validateStep(step)
require.Error(t, err)
assert.True(t, errors.Is(err, ErrStepNameTooLong))
})
t.Run("name at exactly max length passes", func(t *testing.T) {
t.Parallel()
step := Step{Name: strings.Repeat("x", 40), ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
err := validateStep(step)
assert.NoError(t, err)
})
}
func TestValidateStepWithValidator(t *testing.T) {
t.Run("no validator returns nil", func(t *testing.T) {
step := Step{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "unknown-type"},
}
err := validateStepWithValidator(step)
assert.NoError(t, err)
})
t.Run("nil validator returns nil", func(t *testing.T) {
defer delete(stepValidators, "nil-validator-type")
stepValidators["nil-validator-type"] = nil
step := Step{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "nil-validator-type"},
}
err := validateStepWithValidator(step)
assert.NoError(t, err)
})
t.Run("validator error is wrapped", func(t *testing.T) {
defer delete(stepValidators, "wrap-error-type")
customErr := errors.New("custom validation error")
stepValidators["wrap-error-type"] = func(_ Step) error {
return customErr
}
step := Step{
Name: "step1",
ExecutorConfig: ExecutorConfig{Type: "wrap-error-type"},
}
err := validateStepWithValidator(step)
require.Error(t, err)
// Should be wrapped in ValidationError
var ve *ValidationError
require.True(t, errors.As(err, &ve))
assert.Equal(t, "executorConfig", ve.Field)
assert.True(t, errors.Is(err, customErr))
})
}
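// A sketch of the wrapping asserted above, assuming ValidationError is shaped
// roughly like this (the real type may carry more context). Setting Field to
// "executorConfig" and providing Unwrap keeps errors.Is(err, customErr) true
// while errors.As can still extract the field name.
type validationErrorSketch struct {
	Field string
	Err   error
}

func (e *validationErrorSketch) Error() string { return e.Field + ": " + e.Err.Error() }
func (e *validationErrorSketch) Unwrap() error { return e.Err }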
func TestValidateSteps_ComplexScenarios(t *testing.T) {
t.Parallel()
// Use a non-empty executor type to avoid triggering command validators
// that may be registered via init() from other packages
testExecutorType := "test-no-validator"
t.Run("large DAG with many steps", func(t *testing.T) {
t.Parallel()
// Create a DAG with 100 steps in a chain
steps := make([]Step, 100)
for i := 0; i < 100; i++ {
steps[i] = Step{
Name: fmt.Sprintf("step%d", i),
ExecutorConfig: ExecutorConfig{Type: testExecutorType},
}
if i > 0 {
steps[i].Depends = []string{fmt.Sprintf("step%d", i-1)}
}
}
dag := &DAG{Steps: steps}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("diamond dependency pattern", func(t *testing.T) {
t.Parallel()
// A
// / \
// B C
// \ /
// D
dag := &DAG{
Steps: []Step{
{Name: "A", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
{Name: "B", Depends: []string{"A"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
{Name: "C", Depends: []string{"A"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
{Name: "D", Depends: []string{"B", "C"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
t.Run("multiple independent chains", func(t *testing.T) {
t.Parallel()
dag := &DAG{
Steps: []Step{
// Chain 1
{Name: "chain1-step1", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
{Name: "chain1-step2", Depends: []string{"chain1-step1"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
// Chain 2
{Name: "chain2-step1", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
{Name: "chain2-step2", Depends: []string{"chain2-step1"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
},
}
err := ValidateSteps(dag)
assert.NoError(t, err)
})
}

View File

@ -32,6 +32,8 @@ type dockerExecutorTest struct {
}
func TestDockerExecutor(t *testing.T) {
t.Parallel()
tests := []dockerExecutorTest{
{
name: "BasicExecution",
@ -90,6 +92,8 @@ type containerTest struct {
}
func TestDAGLevelContainer(t *testing.T) {
t.Parallel()
tests := []containerTest{
{
name: "VolumeBindMounts",
@ -237,8 +241,6 @@ steps:
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tempDir, err := os.MkdirTemp("", fmt.Sprintf("%s-%s-*", containerPrefix, tt.name))
require.NoError(t, err, "failed to create temporary directory")
t.Cleanup(func() { _ = os.RemoveAll(tempDir) })
@ -513,3 +515,233 @@ func waitForContainerStop(t *testing.T, th test.Helper, dockerClient *client.Cli
}
return false
}
// TestStepLevelContainer tests the new step-level container syntax
// which allows specifying a container field directly on a step instead of
// using the executor syntax.
func TestStepLevelContainer(t *testing.T) {
t.Parallel()
tests := []containerTest{
{
name: "BasicStepContainer",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: run-in-container
container:
image: %s
command: echo "hello from step container"
output: STEP_CONTAINER_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"STEP_CONTAINER_OUT": "hello from step container",
},
},
{
name: "StepContainerWithWorkingDir",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: check-workdir
container:
image: %s
workingDir: /tmp
command: pwd
output: STEP_WORKDIR_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"STEP_WORKDIR_OUT": "/tmp",
},
},
{
name: "StepContainerWithEnv",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: check-env
container:
image: %s
env:
- MY_VAR=hello_world
command: sh -c "echo $MY_VAR"
output: STEP_ENV_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"STEP_ENV_OUT": "hello_world",
},
},
{
name: "StepContainerWithVolume",
dagConfigFunc: func(tempDir string) string {
return fmt.Sprintf(`
steps:
- name: write-file
container:
image: %s
volumes:
- %s:/data
command: sh -c "echo 'step volume test' > /data/step_test.txt"
- name: read-file
container:
image: %s
volumes:
- %s:/data
command: cat /data/step_test.txt
output: STEP_VOL_OUT
depends:
- write-file
`, testImage, tempDir, testImage, tempDir)
},
expectedOutputs: map[string]any{
"STEP_VOL_OUT": "step volume test",
},
},
{
name: "MultipleStepsWithDifferentContainers",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: alpine-step
container:
image: %s
command: cat /etc/alpine-release
output: ALPINE_VERSION
- name: busybox-step
container:
image: busybox:latest
command: echo "busybox step"
output: BUSYBOX_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"BUSYBOX_OUT": "busybox step",
},
},
{
name: "StepContainerOverridesDAGContainer",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
# DAG-level container - steps without container field use this
container:
image: busybox:latest
steps:
- name: use-dag-container
command: echo "in DAG container"
output: DAG_CONTAINER_OUT
- name: use-step-container
container:
image: %s
command: cat /etc/alpine-release
output: STEP_CONTAINER_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"DAG_CONTAINER_OUT": "in DAG container",
},
},
{
name: "StepContainerWithUser",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: check-user
container:
image: %s
user: "nobody"
command: whoami
output: STEP_USER_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"STEP_USER_OUT": "nobody",
},
},
{
name: "StepContainerWithPullPolicy",
dagConfigFunc: func(_ string) string {
return fmt.Sprintf(`
steps:
- name: pull-never
container:
image: %s
pullPolicy: never
command: echo "pull never ok"
output: PULL_NEVER_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"PULL_NEVER_OUT": "pull never ok",
},
},
{
name: "StepEnvMergedIntoContainer",
dagConfigFunc: func(_ string) string {
// Test that step.env is merged with container.env
// container.env takes precedence for shared keys
// Use printenv to show actual environment in container
// Note: the SEMIC_ prefix abbreviates the test name (StepEnvMergedIntoContainer)
// to avoid environment variable collisions between tests
return fmt.Sprintf(`
steps:
- name: check-merged-env
env:
- SEMIC_STEP_VAR=from_step
- SEMIC_SHARED_VAR=step_value
container:
image: %s
env:
- SEMIC_CONTAINER_VAR=from_container
- SEMIC_SHARED_VAR=container_value
command: printenv SEMIC_SHARED_VAR
output: SEMIC_MERGED_ENV_OUT
`, testImage)
},
expectedOutputs: map[string]any{
// SEMIC_SHARED_VAR should be container_value (container.env takes precedence)
"SEMIC_MERGED_ENV_OUT": "container_value",
},
},
{
name: "StepEnvOnlyPassedToContainer",
dagConfigFunc: func(_ string) string {
// Test that step.env is passed to container even without container.env
return fmt.Sprintf(`
steps:
- name: step-env-only
env:
- MY_STEP_VAR=hello_from_step
container:
image: %s
command: printenv MY_STEP_VAR
output: STEP_ENV_ONLY_OUT
`, testImage)
},
expectedOutputs: map[string]any{
"STEP_ENV_ONLY_OUT": "hello_from_step",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir, err := os.MkdirTemp("", fmt.Sprintf("%s-step-%s-*", containerPrefix, tt.name))
require.NoError(t, err, "failed to create temporary directory")
t.Cleanup(func() { _ = os.RemoveAll(tempDir) })
if tt.setupFunc != nil {
tt.setupFunc(t, tempDir)
}
th := test.Setup(t)
dag := th.DAG(t, tt.dagConfigFunc(tempDir))
dag.Agent().RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
dag.AssertOutputs(t, tt.expectedOutputs)
})
}
}
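// Distilled from the cases above into one spec (images and values are
// illustrative): a step-level container replaces the DAG-level one for that
// step only, and step env is merged into container env with container.env
// winning on shared keys.
//
//	container:
//	  image: busybox:latest            # default for steps with no container field
//	steps:
//	  - name: in-dag-container
//	    command: echo "runs in busybox"
//	  - name: in-step-container
//	    env:
//	      - SHARED=step_value          # merged in, loses to container.env
//	    container:
//	      image: alpine:3              # overrides the DAG-level image here only
//	      env:
//	        - SHARED=container_value   # wins on the shared key
//	    command: printenv SHARED       # prints "container_value"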

View File

@ -50,15 +50,15 @@ steps:
// Load the DAG
dagWrapper := coord.DAG(t, yamlContent)
// Build the start command spec
// Build the enqueue command spec
subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
enqueueSpec := subCmdBuilder.Enqueue(dagWrapper.DAG, runtime.EnqueueOptions{
Quiet: true,
})
// Execute the start command (spawns subprocess)
err := runtime.Start(coord.Context, startSpec)
require.NoError(t, err, "Start command should succeed")
// Execute the enqueue command (spawns subprocess)
err := runtime.Start(coord.Context, enqueueSpec)
require.NoError(t, err, "Enqueue command should succeed")
// Wait for the subprocess to complete enqueueing
require.Eventually(t, func() bool {
@ -138,10 +138,10 @@ steps:
t.Log("E2E test completed successfully!")
})
t.Run("E2E_StartCommand_WithNoQueueFlag_ShouldExecuteDirectly", func(t *testing.T) {
// Verify that --no-queue flag bypasses enqueueing even when workerSelector exists
t.Run("E2E_StartCommand_WorkerSelector_ShouldExecuteLocally", func(t *testing.T) {
// Verify that dagu start always executes locally even when workerSelector exists
yamlContent := `
name: no-queue-dag
name: local-start-dag
workerSelector:
test: value
steps:
@ -154,11 +154,10 @@ steps:
// Load the DAG
dagWrapper := coord.DAG(t, yamlContent)
// Build start command WITH --no-queue flag
// Build start command
subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
Quiet: true,
NoQueue: true, // This bypasses enqueueing
Quiet: true,
})
err := runtime.Start(ctx, startSpec)
@ -170,7 +169,7 @@ steps:
// Should NOT be enqueued (executed directly)
queueItems, err := coord.QueueStore.ListByDAGName(ctx, dagWrapper.ProcGroup(), dagWrapper.Name)
require.NoError(t, err)
require.Len(t, queueItems, 0, "DAG should NOT be enqueued when --no-queue is set")
require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")
})
t.Run("E2E_DistributedExecution_Cancellation_SubDAG", func(t *testing.T) {

View File

@ -5,20 +5,12 @@ import (
"time"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/runtime"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/require"
)
// TestStartCommandWithWorkerSelector tests that the start command enqueues
// DAGs with workerSelector instead of executing them locally, and that the
// scheduler dispatches them to workers correctly.
//
// This is the integration test for the distributed execution fix where:
// 1. start command checks for workerSelector → enqueues (instead of executing)
// 2. Scheduler queue handler picks it up → dispatches to coordinator
// 3. Worker executes with --no-queue flag → executes directly (no re-enqueue)
// 3. Worker executes directly as ordered by the scheduler (no re-enqueue)
func TestStartCommandWithWorkerSelector(t *testing.T) {
t.Run("StartCommand_WithWorkerSelector_ShouldEnqueue", func(t *testing.T) {
// This test verifies that when a DAG has workerSelector,
@ -53,16 +45,16 @@ steps:
err := runtime.Start(coord.Context, startSpec)
require.NoError(t, err, "Start command should succeed")
// Wait for the DAG to be enqueued
// Wait for completion (executed locally)
require.Eventually(t, func() bool {
queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
return err == nil && len(queueItems) == 1
}, 2*time.Second, 50*time.Millisecond, "DAG should be enqueued")
status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
return err == nil && status.Status == core.Succeeded
}, 2*time.Second, 50*time.Millisecond, "DAG should complete successfully")
// Verify the DAG was enqueued (not executed locally)
// Verify the DAG was NOT enqueued (executed locally)
queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
require.NoError(t, err)
require.Len(t, queueItems, 1, "DAG should be enqueued once")
require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")
if len(queueItems) > 0 {
data, err := queueItems[0].Data()
@ -70,18 +62,18 @@ steps:
t.Logf("DAG enqueued: dag=%s runId=%s", data.Name, data.ID)
}
// Verify the DAG status is "queued" (not started/running)
// Verify the DAG status is "succeeded" (executed locally)
latest, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
require.NoError(t, err)
require.Equal(t, core.Queued, latest.Status, "DAG status should be queued")
require.Equal(t, core.Succeeded, latest.Status, "DAG status should be succeeded")
})
t.Run("StartCommand_WithNoQueueFlag_ShouldExecuteDirectly", func(t *testing.T) {
// Verify that --no-queue flag bypasses enqueueing
t.Run("StartCommand_WorkerSelector_ShouldExecuteLocally", func(t *testing.T) {
// Verify that dagu start always executes locally
// even when workerSelector exists
yamlContent := `
name: no-queue-dag
name: local-start-dag
workerSelector:
test: value
steps:
@ -94,11 +86,10 @@ steps:
// Load the DAG
dagWrapper := coord.DAG(t, yamlContent)
// Build start command WITH --no-queue flag
// Build start command
subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
Quiet: true,
NoQueue: true, // This bypasses enqueueing
Quiet: true,
})
err := runtime.Start(ctx, startSpec)
@ -107,7 +98,7 @@ steps:
// Should NOT be enqueued (executed directly)
queueItems, err := coord.QueueStore.ListByDAGName(ctx, dagWrapper.ProcGroup(), dagWrapper.Name)
require.NoError(t, err)
require.Len(t, queueItems, 0, "DAG should NOT be enqueued when --no-queue is set")
require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")
// Verify it succeeded (executed locally)
dagWrapper.AssertLatestStatus(t, core.Succeeded)
@ -140,54 +131,39 @@ steps:
Quiet: true,
})
// Execute the start command (runs locally now)
err := runtime.Start(coord.Context, startSpec)
require.NoError(t, err, "Start command should succeed")
require.NoError(t, err, "Start command should succeed (process started)")
// Wait for the DAG to be enqueued
// Wait for completion (executed locally)
require.Eventually(t, func() bool {
queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
return err == nil && len(queueItems) == 1
}, 2*time.Second, 50*time.Millisecond, "DAG should be enqueued")
status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
return err == nil && status.Status == core.Failed
}, 5*time.Second, 100*time.Millisecond, "DAG should fail")
// Verify the DAG was enqueued
// Should NOT be enqueued
queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
require.NoError(t, err)
require.Len(t, queueItems, 1, "DAG should be enqueued once")
require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")
var dagRunID string
var dagRun execution.DAGRunRef
if len(queueItems) > 0 {
data, err := queueItems[0].Data()
require.NoError(t, err, "Should be able to get queue item data")
dagRunID = data.ID
dagRun = *data
t.Logf("DAG enqueued: dag=%s runId=%s", data.Name, data.ID)
}
// Dequeue it to simulate processing
_, err = coord.QueueStore.DequeueByDAGRunID(coord.Context, dagWrapper.ProcGroup(), dagRun)
status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
require.NoError(t, err)
dagRunID := status.DAGRunID
t.Logf("DAG failed: dag=%s runId=%s", status.Name, status.DAGRunID)
// Now retry the DAG - it should be enqueued again
retrySpec := subCmdBuilder.Retry(dagWrapper.DAG, dagRunID, "", false)
err = runtime.Run(coord.Context, retrySpec)
require.NoError(t, err, "Retry command should succeed")
// Now retry the DAG - it should run locally
retrySpec := subCmdBuilder.Retry(dagWrapper.DAG, dagRunID, "")
err = runtime.Start(coord.Context, retrySpec)
require.NoError(t, err, "Retry command should succeed (process started)")
// Wait for the retry to be enqueued
// Wait for completion
require.Eventually(t, func() bool {
queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
return err == nil && len(queueItems) == 1
}, 2*time.Second, 50*time.Millisecond, "Retry should be enqueued")
status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
return err == nil && status.Status == core.Failed
}, 5*time.Second, 100*time.Millisecond, "Retry should fail")
// Verify the retry was enqueued
// Should NOT be enqueued
queueItems, err = coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
require.NoError(t, err)
require.Len(t, queueItems, 1, "Retry should be enqueued once")
if len(queueItems) > 0 {
data, err := queueItems[0].Data()
require.NoError(t, err, "Should be able to get queue item data")
require.Equal(t, dagRunID, data.ID, "Should have same DAG run ID")
t.Logf("Retry enqueued: dag=%s runId=%s", data.Name, data.ID)
}
require.Len(t, queueItems, 0, "Retry should NOT be enqueued (dagu retry runs locally)")
}
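// The spec builders used above, side by side. Under the revised semantics,
// Start and Retry always execute the DAG locally; only Enqueue places a run
// on the queue for the scheduler to dispatch. A condensed sketch:
//
//	b := runtime.NewSubCmdBuilder(coord.Config)
//	startSpec := b.Start(dag, runtime.StartOptions{Quiet: true})       // executes locally
//	enqueueSpec := b.Enqueue(dag, runtime.EnqueueOptions{Quiet: true}) // deferred to a worker
//	retrySpec := b.Retry(dag, dagRunID, "")                            // re-executes locally
//	err := runtime.Start(ctx, startSpec)                               // spawns the subprocess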

View File

@ -1,42 +1,46 @@
package integration_test
import (
"os/exec"
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/require"
)
func TestGitHubActionsExecutor(t *testing.T) {
t.Parallel()
t.Skip("skip")
t.Run("BasicExecution", func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
th := test.Setup(t)
dag := th.DAG(t, `steps:
dag := th.DAG(t, `
workingDir: `+tmpDir+`
steps:
- name: test-action
command: actions/hello-world-javascript-action@main
command: actions/checkout@v4
executor:
type: github_action
config:
runner: node:24-bookworm
runner: node:25-bookworm
params:
who-to-greet: "Morning"
output: ACTION_OUTPUT
repository: dagu-org/dagu
sparse-checkout: README.md
`)
// Verify git is available
_, err := exec.LookPath("git")
require.NoError(t, err, "git is required for this test but not found in PATH")
// Initialize git repo in the temp dir to satisfy act requirements
cmd := exec.Command("git", "init", dag.WorkingDir)
require.NoError(t, cmd.Run(), "failed to init git repo")
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Verify that container output was captured to stdout
// The hello-world action should log "Hello, Morning!" to console
dag.AssertOutputs(t, map[string]any{
"ACTION_OUTPUT": []test.Contains{
"Hello, Morning!",
},
})
})
}

View File

@ -2,11 +2,13 @@ package integration_test
import (
"context"
"strings"
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -255,3 +257,405 @@ steps:
require.NotNil(t, status.OnCancel, "abort handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnCancel.Status)
}
// TestHandlerOn_EnvironmentVariables tests that special environment variables
// are accessible from each handler type.
//
// Environment variables availability by handler:
//
// | Variable                 | Init    | Success   | Failure   | Cancel   | Exit             |
// |--------------------------|---------|-----------|-----------|----------|------------------|
// | DAG_NAME                 | ✓       | ✓         | ✓         | ✓        | ✓                |
// | DAG_RUN_ID               | ✓       | ✓         | ✓         | ✓        | ✓                |
// | DAG_RUN_LOG_FILE         | ✓       | ✓         | ✓         | ✓        | ✓                |
// | DAG_RUN_STEP_NAME        | onInit  | onSuccess | onFailure | onCancel | onExit           |
// | DAG_RUN_STATUS           | running | succeeded | failed    | aborted  | succeeded/failed |
// | DAG_RUN_STEP_STDOUT_FILE | ✗       | ✗         | ✗         | ✗        | ✗                |
// | DAG_RUN_STEP_STDERR_FILE | ✗       | ✗         | ✗         | ✗        | ✗                |
//
// Note: DAG_RUN_STATUS in init handler is "running" because the DAG run has started
// but steps haven't executed yet. DAG_RUN_STEP_STDOUT_FILE and DAG_RUN_STEP_STDERR_FILE
// are not set for handlers because node.SetupEnv is not called during handler execution.
func TestHandlerOn_EnvironmentVariables(t *testing.T) {
t.Parallel()
// Helper to extract value from "KEY=value" format
extractValue := func(output string) string {
if idx := strings.Index(output, "="); idx != -1 {
return output[idx+1:]
}
return output
}
t.Run("InitHandler_BaseEnvVars", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Test that basic env vars (DAG_NAME, DAG_RUN_ID, DAG_RUN_LOG_FILE, DAG_RUN_STEP_NAME)
// are available in the init handler
dag := th.DAG(t, `
handlerOn:
init:
command: |
echo "name:${DAG_NAME}|runid:${DAG_RUN_ID}|logfile:${DAG_RUN_LOG_FILE}|stepname:${DAG_RUN_STEP_NAME}"
output: INIT_ENV_OUTPUT
steps:
- name: step1
command: "true"
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnInit, "init handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnInit.Status)
require.NotNil(t, status.OnInit.OutputVariables, "init handler should have output variables")
output, ok := status.OnInit.OutputVariables.Load("INIT_ENV_OUTPUT")
require.True(t, ok, "INIT_ENV_OUTPUT should be set")
outputStr := extractValue(output.(string))
// Verify DAG_NAME is set and non-empty
assert.Contains(t, outputStr, "name:", "output should contain name prefix")
assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")
// Verify DAG_RUN_ID is set (UUID format)
assert.Contains(t, outputStr, "runid:", "output should contain runid prefix")
assert.NotContains(t, outputStr, "runid:|", "DAG_RUN_ID should not be empty")
// Verify DAG_RUN_LOG_FILE is set and contains .log
assert.Contains(t, outputStr, "logfile:", "output should contain logfile prefix")
assert.Contains(t, outputStr, ".log", "DAG_RUN_LOG_FILE should contain .log")
// Verify DAG_RUN_STEP_NAME is set to "onInit"
assert.Contains(t, outputStr, "stepname:onInit", "DAG_RUN_STEP_NAME should be 'onInit'")
})
t.Run("InitHandler_DAGRunStatus_IsRunning", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// DAG_RUN_STATUS in init handler is "running" because the DAG run has started
// but steps haven't completed yet. This value is technically correct but not
// as useful as having the final status (which isn't known yet).
dag := th.DAG(t, `
handlerOn:
init:
command: echo "${DAG_RUN_STATUS}"
output: INIT_STATUS
steps:
- name: step1
command: "true"
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnInit)
require.NotNil(t, status.OnInit.OutputVariables)
output, ok := status.OnInit.OutputVariables.Load("INIT_STATUS")
require.True(t, ok)
// DAG_RUN_STATUS is "running" for init handler because steps haven't completed
outputStr := extractValue(output.(string))
assert.Equal(t, "running", outputStr, "DAG_RUN_STATUS should be 'running' in init handler")
})
t.Run("SuccessHandler_AllEnvVars", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
handlerOn:
success:
command: |
echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
output: SUCCESS_ENV_OUTPUT
steps:
- name: step1
command: "true"
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnSuccess, "success handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnSuccess.Status)
require.NotNil(t, status.OnSuccess.OutputVariables)
output, ok := status.OnSuccess.OutputVariables.Load("SUCCESS_ENV_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Verify DAG_NAME is set
assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")
// Verify DAG_RUN_STATUS is "succeeded"
assert.Contains(t, outputStr, "status:succeeded", "DAG_RUN_STATUS should be 'succeeded'")
// Verify DAG_RUN_STEP_NAME is "onSuccess"
assert.Contains(t, outputStr, "stepname:onSuccess", "DAG_RUN_STEP_NAME should be 'onSuccess'")
})
t.Run("FailureHandler_AllEnvVars", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
handlerOn:
failure:
command: |
echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
output: FAILURE_ENV_OUTPUT
steps:
- name: failing-step
command: exit 1
`)
agent := dag.Agent()
agent.RunError(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnFailure, "failure handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnFailure.Status)
require.NotNil(t, status.OnFailure.OutputVariables)
output, ok := status.OnFailure.OutputVariables.Load("FAILURE_ENV_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Verify DAG_NAME is set
assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")
// Verify DAG_RUN_STATUS is "failed"
assert.Contains(t, outputStr, "status:failed", "DAG_RUN_STATUS should be 'failed'")
// Verify DAG_RUN_STEP_NAME is "onFailure"
assert.Contains(t, outputStr, "stepname:onFailure", "DAG_RUN_STEP_NAME should be 'onFailure'")
})
t.Run("ExitHandler_AllEnvVars_OnSuccess", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
handlerOn:
exit:
command: |
echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
output: EXIT_ENV_OUTPUT
steps:
- name: step1
command: "true"
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnExit, "exit handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnExit.Status)
require.NotNil(t, status.OnExit.OutputVariables)
output, ok := status.OnExit.OutputVariables.Load("EXIT_ENV_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Verify DAG_NAME is set
assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")
// Verify DAG_RUN_STATUS is "succeeded" (exit runs after success)
assert.Contains(t, outputStr, "status:succeeded", "DAG_RUN_STATUS should be 'succeeded'")
// Verify DAG_RUN_STEP_NAME is "onExit"
assert.Contains(t, outputStr, "stepname:onExit", "DAG_RUN_STEP_NAME should be 'onExit'")
})
t.Run("ExitHandler_AllEnvVars_OnFailure", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
handlerOn:
exit:
command: |
echo "status:${DAG_RUN_STATUS}"
output: EXIT_ENV_OUTPUT
steps:
- name: failing-step
command: exit 1
`)
agent := dag.Agent()
agent.RunError(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnExit, "exit handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnExit.Status)
require.NotNil(t, status.OnExit.OutputVariables)
output, ok := status.OnExit.OutputVariables.Load("EXIT_ENV_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Verify DAG_RUN_STATUS is "failed" (exit runs after failure)
assert.Contains(t, outputStr, "status:failed", "DAG_RUN_STATUS should be 'failed'")
})
t.Run("CancelHandler_AllEnvVars", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
handlerOn:
abort:
command: |
echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
output: CANCEL_ENV_OUTPUT
steps:
- name: long-running
command: sleep 10
`)
dagAgent := dag.Agent()
done := make(chan struct{})
go func() {
_ = dagAgent.Run(th.Context)
close(done)
}()
dag.AssertLatestStatus(t, core.Running)
dagAgent.Abort()
<-done
status := dagAgent.Status(th.Context)
require.NotNil(t, status.OnCancel, "cancel handler should have been executed")
require.Equal(t, core.NodeSucceeded, status.OnCancel.Status)
require.NotNil(t, status.OnCancel.OutputVariables)
output, ok := status.OnCancel.OutputVariables.Load("CANCEL_ENV_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Verify DAG_NAME is set
assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")
// Verify DAG_RUN_STATUS is "aborted"
assert.Contains(t, outputStr, "status:aborted", "DAG_RUN_STATUS should be 'aborted'")
// Verify DAG_RUN_STEP_NAME is "onCancel"
assert.Contains(t, outputStr, "stepname:onCancel", "DAG_RUN_STEP_NAME should be 'onCancel'")
})
t.Run("StepOutputVars_NotAvailableInHandlers", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// DAG_RUN_STEP_STDOUT_FILE and DAG_RUN_STEP_STDERR_FILE are NOT set
// for handlers because node.SetupEnv is not called for handler execution
dag := th.DAG(t, `
handlerOn:
success:
command: |
echo "stdout:${DAG_RUN_STEP_STDOUT_FILE:-UNSET}|stderr:${DAG_RUN_STEP_STDERR_FILE:-UNSET}"
output: HANDLER_STEP_FILES
steps:
- name: step1
command: "true"
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnSuccess)
require.NotNil(t, status.OnSuccess.OutputVariables)
output, ok := status.OnSuccess.OutputVariables.Load("HANDLER_STEP_FILES")
require.True(t, ok)
outputStr := extractValue(output.(string))
// These should be unset/empty in handlers
assert.Contains(t, outputStr, "stdout:UNSET", "DAG_RUN_STEP_STDOUT_FILE should not be set in handler")
assert.Contains(t, outputStr, "stderr:UNSET", "DAG_RUN_STEP_STDERR_FILE should not be set in handler")
})
t.Run("Handlers_CanAccessStepOutputVariables", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Handlers can access output variables from steps that have completed
dag := th.DAG(t, `
handlerOn:
success:
command: |
echo "step_output:${STEP_OUTPUT}"
output: SUCCESS_WITH_STEP_OUTPUT
steps:
- name: producer
command: echo "produced_value"
output: STEP_OUTPUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnSuccess)
require.NotNil(t, status.OnSuccess.OutputVariables)
output, ok := status.OnSuccess.OutputVariables.Load("SUCCESS_WITH_STEP_OUTPUT")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Success handler should be able to access step output
assert.Contains(t, outputStr, "step_output:produced_value", "handler should access step output")
})
t.Run("InitHandler_CannotAccessStepOutputVariables", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Init handler runs BEFORE steps, so it cannot access step outputs
dag := th.DAG(t, `
handlerOn:
init:
command: |
echo "step_output:${STEP_OUTPUT:-NOT_YET_AVAILABLE}"
output: INIT_STEP_ACCESS
steps:
- name: producer
command: echo "produced_value"
output: STEP_OUTPUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
status := agent.Status(th.Context)
require.NotNil(t, status.OnInit)
require.NotNil(t, status.OnInit.OutputVariables)
output, ok := status.OnInit.OutputVariables.Load("INIT_STEP_ACCESS")
require.True(t, ok)
outputStr := extractValue(output.(string))
// Init handler cannot access step output (steps haven't run yet)
assert.Contains(t, outputStr, "step_output:NOT_YET_AVAILABLE",
"init handler should not access step outputs")
})
}
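// Putting the table above to work: a compact spec whose exit handler reports
// the run's identity and final status, assuming only the variables documented
// there.
//
//	handlerOn:
//	  exit:
//	    command: echo "run ${DAG_RUN_ID} of ${DAG_NAME} finished as ${DAG_RUN_STATUS}"
//	steps:
//	  - name: work
//	    command: "true"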

View File

@ -1,463 +0,0 @@
package integration_test
import (
"os"
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/require"
)
func TestLocalDAGExecution(t *testing.T) {
t.Run("SimpleLocalDAG", func(t *testing.T) {
// Create a DAG with local sub DAGs using separator
yamlContent := `
steps:
- name: run-local-child
call: local-child
params: "NAME=World"
output: SUB_RESULT
- echo "Child said ${SUB_RESULT.outputs.GREETING}"
---
name: local-child
params:
- NAME
steps:
- command: echo "Hello, ${NAME}!"
output: GREETING
- echo "Greeting was ${GREETING}"
`
// Setup test helper
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
// Run the DAG
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
// Verify successful completion
testDAG.AssertLatestStatus(t, core.Succeeded)
// Get the full run status
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Verify the first step (run-local-child) completed successfully
// Note: The sub DAG's output is not directly visible in the parent's stdout
require.Len(t, dagRunStatus.Nodes, 2)
require.Equal(t, "run-local-child", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
// Verify the second step output
logContent, err := os.ReadFile(dagRunStatus.Nodes[1].Stdout)
require.NoError(t, err)
require.Contains(t, string(logContent), "Child said Hello, World!")
})
t.Run("ParallelLocalDAGExecution", func(t *testing.T) {
// Create a DAG with parallel execution of local DAGs
yamlContent := `
steps:
- name: parallel-tasks
call: worker-dag
parallel:
items:
- TASK_ID=1 TASK_NAME=alpha
- TASK_ID=2 TASK_NAME=beta
- TASK_ID=3 TASK_NAME=gamma
maxConcurrent: 2
---
name: worker-dag
params:
- TASK_ID
- TASK_NAME
steps:
- echo "Starting task ${TASK_ID} - ${TASK_NAME}"
- echo "Processing ${TASK_NAME} with ID ${TASK_ID}"
- echo "Completed ${TASK_NAME}"
`
// Setup test helper
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
// Run the DAG
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
// Verify successful completion
testDAG.AssertLatestStatus(t, core.Succeeded)
// Get the full run status
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// For parallel execution, we should have one step that ran multiple instances
require.Len(t, dagRunStatus.Nodes, 1)
require.Equal(t, "parallel-tasks", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
})
t.Run("NestedLocalDAGs", func(t *testing.T) {
// Test that nested local DAGs beyond 1 level are not supported
// This should fail because middle-dag tries to run leaf-dag, but leaf-dag
// is not visible to middle-dag (only to root-dag)
yamlContent := `
steps:
- name: run-middle-dag
call: middle-dag
params: "ROOT_PARAM=FromRoot"
---
name: middle-dag
params:
- ROOT_PARAM
steps:
- command: echo "Received ${ROOT_PARAM}"
output: MIDDLE_OUTPUT
- name: run-leaf-dag
call: leaf-dag
params: "MIDDLE_PARAM=${MIDDLE_OUTPUT} LEAF_PARAM=FromMiddle"
---
name: leaf-dag
params:
- MIDDLE_PARAM
- LEAF_PARAM
steps:
- command: |
echo "Middle: ${MIDDLE_PARAM}, Leaf: ${LEAF_PARAM}"
`
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
agent := testDAG.Agent()
err := agent.Run(agent.Context)
// The root DAG execution will fail because middle-dag fails
require.Error(t, err)
require.Contains(t, err.Error(), "failed")
// This should fail because middle-dag cannot see leaf-dag
testDAG.AssertLatestStatus(t, core.Failed)
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Root DAG should have one step that tried to run middle-dag
require.Len(t, dagRunStatus.Nodes, 1)
require.Equal(t, "run-middle-dag", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeFailed, dagRunStatus.Nodes[0].Status)
})
t.Run("LocalDAGWithConditionalExecution", func(t *testing.T) {
// Test conditional execution with local DAGs
yamlContent := `
env:
- ENVIRONMENT: production
steps:
- name: check-env
command: echo "${ENVIRONMENT}"
output: ENV_TYPE
- name: run-prod-dag
call: production-dag
preconditions:
- condition: "${ENV_TYPE}"
expected: "production"
- name: run-dev-dag
call: development-dag
preconditions:
- condition: "${ENV_TYPE}"
expected: "development"
---
name: production-dag
steps:
- echo "Deploying to production"
- echo "Verifying production deployment"
---
name: development-dag
steps:
- echo "Building for development"
- echo "Running development tests"
`
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
testDAG.AssertLatestStatus(t, core.Succeeded)
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Should have 3 steps: check-env, run-prod-dag, run-dev-dag
require.Len(t, dagRunStatus.Nodes, 3)
// Check environment step
require.Equal(t, "check-env", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
// Production DAG should run
require.Equal(t, "run-prod-dag", dagRunStatus.Nodes[1].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)
// Development DAG should be skipped
require.Equal(t, "run-dev-dag", dagRunStatus.Nodes[2].Step.Name)
require.Equal(t, core.NodeSkipped, dagRunStatus.Nodes[2].Status)
})
t.Run("LocalDAGWithOutputPassing", func(t *testing.T) {
// Test passing outputs between local DAGs
yamlContent := `
steps:
- name: generate-data
call: generator-dag
output: GEN_OUTPUT
- name: process-data
call: processor-dag
params: "INPUT_DATA=${GEN_OUTPUT.outputs.DATA}"
---
name: generator-dag
steps:
- command: echo "test-value-42"
output: DATA
---
name: processor-dag
params:
- INPUT_DATA
steps:
- command: echo "Processing ${INPUT_DATA}"
output: RESULT
- command: |
echo "Validated: ${RESULT}"
`
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
testDAG.AssertLatestStatus(t, core.Succeeded)
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Should have 2 steps
require.Len(t, dagRunStatus.Nodes, 2)
// First step generates data
require.Equal(t, "generate-data", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
// Second step processes data
require.Equal(t, "process-data", dagRunStatus.Nodes[1].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)
})
t.Run("LocalDAGReferencesNonExistent", func(t *testing.T) {
// Test error when referencing non-existent local DAG
yamlContent := `
steps:
- name: run-missing-dag
call: non-existent-dag
---
name: some-other-dag
steps:
- echo "test"
`
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
agent := testDAG.Agent()
err := agent.Run(agent.Context)
// The agent will return an error when a step fails
require.Error(t, err)
require.Contains(t, err.Error(), "non-existent-dag")
// Check that the DAG failed
testDAG.AssertLatestStatus(t, core.Failed)
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Should have one step that failed
require.Len(t, dagRunStatus.Nodes, 1)
require.Equal(t, "run-missing-dag", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeFailed, dagRunStatus.Nodes[0].Status)
})
t.Run("LocalDAGWithComplexDependencies", func(t *testing.T) {
// Test complex dependencies between local DAGs
yamlContent := `
steps:
- name: setup
command: echo "Setting up"
output: SETUP_STATUS
- name: task1
call: task-dag
params: "TASK_NAME=Task1 SETUP=${SETUP_STATUS}"
output: TASK1_RESULT
- name: task2
call: task-dag
params: "TASK_NAME=Task2 SETUP=${SETUP_STATUS}"
output: TASK2_RESULT
- name: combine
command: |
echo "Combining ${TASK1_RESULT.outputs.RESULT} and ${TASK2_RESULT.outputs.RESULT}"
depends:
- task1
- task2
---
name: task-dag
params:
- TASK_NAME
- SETUP
steps:
- command: echo "${TASK_NAME} processing with ${SETUP}"
output: RESULT
`
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
testDAG.AssertLatestStatus(t, core.Succeeded)
dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
require.NoError(t, err)
// Should have 4 steps: setup, task1, task2, combine
require.Len(t, dagRunStatus.Nodes, 4)
// Verify each step
require.Equal(t, "setup", dagRunStatus.Nodes[0].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
require.Equal(t, "task1", dagRunStatus.Nodes[1].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)
require.Equal(t, "task2", dagRunStatus.Nodes[2].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[2].Status)
require.Equal(t, "combine", dagRunStatus.Nodes[3].Step.Name)
require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[3].Status)
// Verify the combine step output
logContent, err := os.ReadFile(dagRunStatus.Nodes[3].Stdout)
require.NoError(t, err)
require.Contains(t, string(logContent), "Combining")
require.Contains(t, string(logContent), "Task1 processing with Setting up")
require.Contains(t, string(logContent), "Task2 processing with Setting up")
})
t.Run("PartialSuccessParallel", func(t *testing.T) {
// Create a DAG with parallel execution of local DAGs
yamlContent := `
steps:
- name: parallel-tasks
call: worker-dag
parallel:
items:
- TASK_ID=1 TASK_NAME=alpha
---
name: worker-dag
params:
- TASK_ID
- TASK_NAME
steps:
- command: exit 1
continueOn:
failure: true
- exit 0
`
// Setup test helper
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
// Run the DAG
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
// Verify successful completion
testDAG.AssertLatestStatus(t, core.PartiallySucceeded)
})
t.Run("PartialSuccessSubDAG", func(t *testing.T) {
// Create a DAG with parallel execution of local DAGs
yamlContent := `
steps:
- name: parallel-tasks
call: worker-dag
---
name: worker-dag
params:
- TASK_ID
- TASK_NAME
steps:
- command: exit 1
continueOn:
failure: true
- exit 0
`
// Setup test helper
th := test.Setup(t)
// Load the DAG using helper
testDAG := th.DAG(t, yamlContent)
// Run the DAG
agent := testDAG.Agent()
require.NoError(t, agent.Run(agent.Context))
// Verify successful completion
testDAG.AssertLatestStatus(t, core.PartiallySucceeded)
})
}

View File

@ -0,0 +1,468 @@
package integration_test
import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/spec"
"github.com/dagu-org/dagu/internal/test"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestMultipleCommands_Shell(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping Unix shell tests on Windows")
}
t.Parallel()
th := test.Setup(t)
t.Run("TwoCommands", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: multi-cmd
command:
- echo hello
- echo world
output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
t.Run("ThreeCommandsWithArgs", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: multi-cmd
command:
- echo "first command"
- echo "second command"
- echo "third command"
output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "first command\nsecond command\nthird command",
})
})
t.Run("CommandsWithEnvVars", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
env:
- MY_VAR: hello
steps:
- name: multi-cmd
command:
- echo $MY_VAR
- echo "${MY_VAR} world"
output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nhello world",
})
})
t.Run("FirstCommandFails", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: multi-cmd
command:
- "false"
- echo "should not run"
output: OUT
`)
agent := dag.Agent()
agent.RunError(t)
dag.AssertLatestStatus(t, core.Failed)
})
t.Run("SecondCommandFails", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: multi-cmd
command:
- echo "first runs"
- "false"
- echo "should not run"
output: OUT
`)
agent := dag.Agent()
agent.RunError(t)
dag.AssertLatestStatus(t, core.Failed)
})
t.Run("CommandsWithPipes", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
shell: /bin/bash
steps:
- name: multi-cmd
command:
- echo "hello world" | tr 'h' 'H'
- echo "foo bar" | tr 'f' 'F'
output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "Hello world\nFoo bar",
})
})
t.Run("CommandsWithWorkingDir", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: multi-cmd
workingDir: /tmp
command:
- pwd
- echo "done"
output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "/tmp\ndone",
})
})
t.Run("DependsOnPreviousStep", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
- name: step1
command: echo "step1"
output: STEP1_OUT
- name: step2
depends:
- step1
command:
- echo "from step2"
- echo "done"
output: STEP2_OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"STEP1_OUT": "step1",
"STEP2_OUT": "from step2\ndone",
})
})
}
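// A minimal sketch of the sequencing these shell cases rely on: commands run
// in order through the configured shell, execution stops at the first
// non-zero exit, and the step's captured output is the stdout of all commands
// joined by newlines. The shell invocation and trailing-newline trimming here
// are assumptions, not the engine's exact code. Assumes imports of context,
// fmt, os/exec, and strings.
func runCommandsSketch(ctx context.Context, shell string, cmds []string) (string, error) {
	var outs []string
	for _, c := range cmds {
		out, err := exec.CommandContext(ctx, shell, "-c", c).Output()
		if err != nil {
			return "", fmt.Errorf("command %q failed: %w", c, err)
		}
		outs = append(outs, strings.TrimRight(string(out), "\n"))
	}
	return strings.Join(outs, "\n"), nil
}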
func TestMultipleCommands_Docker(t *testing.T) {
t.Parallel()
const testImage = "alpine:3"
t.Run("TwoCommandsInContainer", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Use startup: command to keep container running for multiple commands
dag := th.DAG(t, fmt.Sprintf(`
steps:
- name: multi-cmd
container:
image: %s
startup: command
command: ["sh", "-c", "while true; do sleep 3600; done"]
command:
- echo hello
- echo world
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
t.Run("CommandsWithEnvInContainer", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Use startup: command to keep container running for multiple commands
dag := th.DAG(t, fmt.Sprintf(`
steps:
- name: multi-cmd
container:
image: %s
startup: command
command: ["sh", "-c", "while true; do sleep 3600; done"]
env:
- MY_VAR=hello
command:
- printenv MY_VAR
- echo "done"
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\ndone",
})
})
t.Run("FirstCommandFailsInContainer", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Use startup: command to keep container running for multiple commands
dag := th.DAG(t, fmt.Sprintf(`
steps:
- name: multi-cmd
container:
image: %s
startup: command
command: ["sh", "-c", "while true; do sleep 3600; done"]
command:
- "false"
- echo "should not run"
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunError(t)
dag.AssertLatestStatus(t, core.Failed)
})
t.Run("SecondCommandFailsInContainer", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Use startup: command to keep container running for multiple commands
dag := th.DAG(t, fmt.Sprintf(`
steps:
- name: multi-cmd
container:
image: %s
startup: command
command: ["sh", "-c", "while true; do sleep 3600; done"]
command:
- echo "first runs"
- "false"
- echo "should not run"
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunError(t)
dag.AssertLatestStatus(t, core.Failed)
})
t.Run("DAGLevelContainerWithMultipleCommands", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, fmt.Sprintf(`
container:
image: %s
steps:
- name: multi-cmd
command:
- echo hello
- echo world
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
t.Run("MultipleStepsWithMultipleCommands", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, fmt.Sprintf(`
container:
image: %s
steps:
- name: step1
command:
- echo "step1-cmd1"
- echo "step1-cmd2"
output: STEP1_OUT
- name: step2
depends:
- step1
command:
- echo "step2-cmd1"
- echo "step2-cmd2"
output: STEP2_OUT
`, testImage))
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"STEP1_OUT": "step1-cmd1\nstep1-cmd2",
"STEP2_OUT": "step2-cmd1\nstep2-cmd2",
})
})
// Test step-level container without startup:command - uses default keepalive mode
t.Run("StepContainerWithDefaultKeepalive", func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// No startup:command - should use default keepalive mode
dag := th.DAG(t, fmt.Sprintf(`
steps:
- name: multi-cmd
container:
image: %s
command:
- echo hello
- echo world
output: OUT
`, testImage))
agent := dag.Agent()
agent.RunSuccess(t)
dag.AssertLatestStatus(t, core.Succeeded)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
}
func TestMultipleCommands_Validation(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping Unix shell tests on Windows")
}
t.Parallel()
th := test.Setup(t)
t.Run("JQExecutorRejectsMultipleCommands", func(t *testing.T) {
t.Parallel()
// Create a temp file with the DAG content
tempDir := t.TempDir()
filename := fmt.Sprintf("%s.yaml", uuid.New().String())
testFile := filepath.Join(tempDir, filename)
yamlContent := `
steps:
- name: jq-multi
executor: jq
command:
- ".foo"
- ".bar"
script: '{"foo": "bar"}'
`
err := os.WriteFile(testFile, []byte(yamlContent), 0600)
require.NoError(t, err)
_, err = spec.Load(th.Context, testFile)
require.Error(t, err, "expected error for multiple commands with jq executor")
require.Contains(t, err.Error(), "executor does not support multiple commands")
})
t.Run("HTTPExecutorRejectsMultipleCommands", func(t *testing.T) {
t.Parallel()
// Create a temp file with the DAG content
tempDir := t.TempDir()
filename := fmt.Sprintf("%s.yaml", uuid.New().String())
testFile := filepath.Join(tempDir, filename)
yamlContent := `
steps:
  - name: http-multi
    executor: http
    command:
      - "GET https://example.com"
      - "POST https://example.com"
`
err := os.WriteFile(testFile, []byte(yamlContent), 0600)
require.NoError(t, err)
_, err = spec.Load(th.Context, testFile)
require.Error(t, err, "expected error for multiple commands with http executor")
require.Contains(t, err.Error(), "executor does not support multiple commands")
})
t.Run("ShellExecutorAllowsMultipleCommands", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
  - name: shell-multi
    executor: shell
    command:
      - echo hello
      - echo world
    output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
t.Run("CommandExecutorAllowsMultipleCommands", func(t *testing.T) {
t.Parallel()
dag := th.DAG(t, `
steps:
  - name: cmd-multi
    executor: command
    command:
      - echo hello
      - echo world
    output: OUT
`)
agent := dag.Agent()
agent.RunSuccess(t)
// Output is concatenated from all commands
dag.AssertOutputs(t, map[string]any{
"OUT": "hello\nworld",
})
})
}
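// Sketch of the output-concatenation rule asserted throughout this file: each
// command's captured stdout is joined with a single newline to form the step's
// output value. Helper name and edge-case handling are assumptions for
// illustration, not the production implementation.
func concatOutputsSketch(parts []string) string {
    joined := ""
    for i, p := range parts {
        if i > 0 {
            joined += "\n" // one newline between per-command outputs
        }
        joined += p
    }
    return joined
}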

@ -0,0 +1,439 @@
package integration_test
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/persistence/filedagrun"
"github.com/dagu-org/dagu/internal/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOutputsCollection(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dagYAML string
runFunc func(*testing.T, context.Context, *test.Agent)
validateFunc func(*testing.T, execution.DAGRunStatus)
validateOutputs func(*testing.T, map[string]string)
}{
{
name: "SimpleStringOutput",
dagYAML: `
steps:
  - name: produce-output
    command: echo "RESULT=42"
    output: RESULT
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
require.Len(t, status.Nodes, 1)
require.Equal(t, core.NodeSucceeded, status.Nodes[0].Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
// Output value includes the KEY= prefix from command output
assert.Equal(t, "RESULT=42", outputs["result"]) // SCREAMING_SNAKE to camelCase
},
},
{
name: "OutputWithCustomKey",
dagYAML: `
steps:
  - name: produce-output
    command: echo "MY_VALUE=hello world"
    output:
      name: MY_VALUE
      key: customKeyName
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
// Value includes the original KEY= prefix
assert.Equal(t, "MY_VALUE=hello world", outputs["customKeyName"])
_, hasDefault := outputs["myValue"]
assert.False(t, hasDefault, "should not have default key when custom key is specified")
},
},
{
name: "OutputWithOmit",
dagYAML: `
steps:
  - name: step1
    command: echo "VISIBLE=yes"
    output: VISIBLE
  - name: step2
    command: echo "HIDDEN=secret"
    output:
      name: HIDDEN
      omit: true
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
require.Len(t, status.Nodes, 2)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
assert.Equal(t, "VISIBLE=yes", outputs["visible"])
_, hasHidden := outputs["hidden"]
assert.False(t, hasHidden, "omitted output should not be in outputs.json")
},
},
{
name: "MultipleStepsWithOutputs",
dagYAML: `
steps:
  - name: step1
    command: echo "COUNT=10"
    output: COUNT
  - name: step2
    command: echo "TOTAL=100"
    output: TOTAL
  - name: step3
    command: echo "STATUS=completed"
    output: STATUS
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
require.Len(t, status.Nodes, 3)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
assert.Len(t, outputs, 3)
assert.Equal(t, "COUNT=10", outputs["count"])
assert.Equal(t, "TOTAL=100", outputs["total"])
assert.Equal(t, "STATUS=completed", outputs["status"])
},
},
{
name: "LastOneWinsForDuplicateKeys",
dagYAML: `
steps:
  - name: step1
    command: echo "VALUE=first"
    output: VALUE
  - name: step2
    depends: [step1]
    command: echo "VALUE=second"
    output: VALUE
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
// Last step wins
assert.Equal(t, "VALUE=second", outputs["value"])
},
},
{
name: "NoOutputsProduced",
dagYAML: `
steps:
  - name: step1
    command: echo "hello"
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
// No outputs.json file should be created when no outputs are produced
assert.Nil(t, outputs)
},
},
{
name: "OutputWithDollarPrefix",
dagYAML: `
steps:
  - name: step1
    command: echo "MY_VAR=value123"
    output: $MY_VAR
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
assert.Equal(t, "MY_VAR=value123", outputs["myVar"])
},
},
{
name: "MixedOutputConfigurations",
dagYAML: `
steps:
  - name: simple
    command: echo "SIMPLE_OUT=simple_value"
    output: SIMPLE_OUT
  - name: with-key
    command: echo "KEYED=keyed_value"
    output:
      name: KEYED
      key: renamedKey
  - name: omitted
    command: echo "SECRET=secret_value"
    output:
      name: SECRET
      omit: true
`,
runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
agent.RunSuccess(t)
},
validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
require.Equal(t, core.Succeeded, status.Status)
},
validateOutputs: func(t *testing.T, outputs map[string]string) {
require.NotNil(t, outputs)
assert.Len(t, outputs, 2) // simple + keyed, NOT secret
assert.Equal(t, "SIMPLE_OUT=simple_value", outputs["simpleOut"])
assert.Equal(t, "KEYED=keyed_value", outputs["renamedKey"])
_, hasSecret := outputs["secret"]
assert.False(t, hasSecret)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, tt.dagYAML)
agent := dag.Agent()
// Run the DAG
tt.runFunc(t, agent.Context, agent)
// Validate DAG run status
status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
require.NoError(t, err)
tt.validateFunc(t, status)
// Read outputs.json if it exists
outputs := readOutputsFile(t, th, dag.DAG)
tt.validateOutputs(t, outputs)
})
}
}
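// Sketch of the key-resolution rules the table above exercises (type and
// helpers are hypothetical; toCamelSketch is defined further below): omit
// drops the output entirely, an explicit key overrides the derived one, and
// otherwise the env-style name is converted to camelCase.
type outputCfgSketch struct {
    Name string // env-style name, e.g. "MY_VALUE"
    Key  string // optional explicit key, e.g. "customKeyName"
    Omit bool   // exclude from outputs.json entirely
}

func resolveKeySketch(c outputCfgSketch) (key string, keep bool) {
    if c.Omit {
        return "", false // omitted outputs never reach outputs.json
    }
    if c.Key != "" {
        return c.Key, true // a custom key wins over the derived one
    }
    return toCamelSketch(c.Name), true
}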
func TestOutputsCollection_FailedDAG(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "BEFORE_FAIL=collected"
    output: BEFORE_FAIL
  - name: step2
    depends: [step1]
    command: exit 1
  - name: step3
    depends: [step2]
    command: echo "AFTER_FAIL=not_collected"
    output: AFTER_FAIL
`)
agent := dag.Agent()
_ = agent.Run(agent.Context)
status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
require.NoError(t, err)
require.Equal(t, core.Failed, status.Status)
// Outputs from successful steps should still be collected
outputs := readOutputsFile(t, th, dag.DAG)
require.NotNil(t, outputs)
assert.Equal(t, "BEFORE_FAIL=collected", outputs["beforeFail"])
_, hasAfterFail := outputs["afterFail"]
assert.False(t, hasAfterFail, "output from step after failure should not be collected")
}
func TestOutputsCollection_CamelCaseConversion(t *testing.T) {
t.Parallel()
tests := []struct {
envVarName string
expectedKey string
expectedValue string
}{
{"SIMPLE", "simple", "SIMPLE=test_value"},
{"TWO_WORDS", "twoWords", "TWO_WORDS=test_value"},
{"MULTIPLE_WORD_NAME", "multipleWordName", "MULTIPLE_WORD_NAME=test_value"},
{"ALREADY_CAMEL_Case", "alreadyCamelCase", "ALREADY_CAMEL_Case=test_value"},
}
for _, tt := range tests {
t.Run(tt.envVarName, func(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "`+tt.envVarName+`=test_value"
    output: `+tt.envVarName+`
`)
agent := dag.Agent()
agent.RunSuccess(t)
outputs := readOutputsFile(t, th, dag.DAG)
require.NotNil(t, outputs)
// Value includes the KEY= prefix from the original output
assert.Equal(t, tt.expectedValue, outputs[tt.expectedKey])
})
}
}
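// Sketch of the SCREAMING_SNAKE_CASE -> camelCase conversion the table above
// pins down (e.g. "MULTIPLE_WORD_NAME" -> "multipleWordName"). Hypothetical
// helper for illustration; the production converter may differ in edge cases.
func toCamelSketch(s string) string {
    out := make([]byte, 0, len(s))
    upperNext := false
    for i := 0; i < len(s); i++ {
        c := s[i]
        if c == '_' {
            upperNext = true // the next letter starts a new word
            continue
        }
        if 'A' <= c && c <= 'Z' {
            c += 'a' - 'A' // lowercase everything by default
        }
        if upperNext && len(out) > 0 && 'a' <= c && c <= 'z' {
            c -= 'a' - 'A' // re-capitalize word boundaries after the first word
        }
        upperNext = false
        out = append(out, c)
    }
    return string(out)
}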
func TestOutputsCollection_SecretsMasked(t *testing.T) {
t.Parallel()
th := test.Setup(t)
// Create a temporary secret file
secretValue := "super-secret-api-token-xyz123"
secretFile := th.TempFile(t, "secret.txt", []byte(secretValue))
dag := th.DAG(t, `
secrets:
  - name: API_TOKEN
    provider: file
    key: `+secretFile+`
steps:
  - name: output-secret
    command: echo "TOKEN=${API_TOKEN}"
    output: TOKEN
`)
agent := dag.Agent()
agent.RunSuccess(t)
status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
require.NoError(t, err)
require.Equal(t, core.Succeeded, status.Status)
// Read outputs.json
outputs := readOutputsFile(t, th, dag.DAG)
require.NotNil(t, outputs)
// The output value should contain the masked secret, not the actual value
tokenOutput := outputs["token"]
require.NotEmpty(t, tokenOutput)
assert.NotContains(t, tokenOutput, secretValue, "secret value should be masked in outputs")
assert.Contains(t, tokenOutput, "*******", "masked placeholder should appear in outputs")
}
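// Sketch of the masking behavior asserted above: every occurrence of a
// registered secret value is replaced with a fixed placeholder before outputs
// are persisted. Helper name and placeholder handling are illustrative
// assumptions grounded only in the assertions.
func maskSecretSketch(s, secret string) string {
    const placeholder = "*******"
    if secret == "" {
        return s
    }
    masked := ""
    for i := 0; i+len(secret) <= len(s); {
        if s[i:i+len(secret)] == secret {
            masked += s[:i] + placeholder // replace the match
            s = s[i+len(secret):]
            i = 0
            continue
        }
        i++
    }
    return masked + s
}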
func TestOutputsCollection_MetadataIncluded(t *testing.T) {
t.Parallel()
th := test.Setup(t)
dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "RESULT=42"
    output: RESULT
`)
agent := dag.Agent()
agent.RunSuccess(t)
status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
require.NoError(t, err)
// Read full outputs including metadata
fullOutputs := readFullOutputsFile(t, th, dag.DAG)
require.NotNil(t, fullOutputs)
// Validate metadata
assert.Equal(t, dag.Name, fullOutputs.Metadata.DAGName)
assert.Equal(t, status.DAGRunID, fullOutputs.Metadata.DAGRunID)
assert.NotEmpty(t, fullOutputs.Metadata.AttemptID)
assert.Equal(t, "succeeded", fullOutputs.Metadata.Status)
assert.NotEmpty(t, fullOutputs.Metadata.CompletedAt)
// Validate outputs are still present
assert.Equal(t, "RESULT=42", fullOutputs.Outputs["result"])
}
// readOutputsFile reads the outputs.json file for a given DAG run
// Returns just the outputs map for backward compatibility with existing tests
func readOutputsFile(t *testing.T, th test.Helper, dag *core.DAG) map[string]string {
t.Helper()
fullOutputs := readFullOutputsFile(t, th, dag)
if fullOutputs == nil {
return nil
}
return fullOutputs.Outputs
}
// readFullOutputsFile reads the full outputs.json file including metadata
func readFullOutputsFile(t *testing.T, th test.Helper, dag *core.DAG) *execution.DAGRunOutputs {
t.Helper()
// Find the attempt directory
dagRunsDir := th.Config.Paths.DAGRunsDir
dagRunDir := filepath.Join(dagRunsDir, dag.Name, "dag-runs")
// Walk to find the outputs.json file
var outputsPath string
_ = filepath.Walk(dagRunDir, func(path string, info os.FileInfo, err error) error {
require.NoError(t, err)
if info.Name() == filedagrun.OutputsFile {
outputsPath = path
return filepath.SkipAll
}
return nil
})
if outputsPath == "" {
return nil
}
data, err := os.ReadFile(outputsPath)
require.NoError(t, err)
var outputs execution.DAGRunOutputs
require.NoError(t, json.Unmarshal(data, &outputs))
// Fail the test if the file is in the old format (no metadata)
require.NotEmpty(t, outputs.Metadata.DAGRunID)
return &outputs
}
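// For reference, a sketch of the outputs.json document shape these helpers
// decode, inferred from the fields asserted in this file (the exact JSON
// field names are assumptions):
//
//	{
//	  "metadata": {
//	    "dagName": "...",
//	    "dagRunId": "...",
//	    "attemptId": "...",
//	    "status": "succeeded",
//	    "completedAt": "..."
//	  },
//	  "outputs": {"result": "RESULT=42"}
//	}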

@ -8,6 +8,8 @@ import (
"time"
"github.com/dagu-org/dagu/internal/cmd"
"github.com/dagu-org/dagu/internal/common/config"
"github.com/dagu-org/dagu/internal/common/stringutil"
"github.com/dagu-org/dagu/internal/core"
"github.com/dagu-org/dagu/internal/core/execution"
"github.com/dagu-org/dagu/internal/core/spec"
@ -137,6 +139,304 @@ steps:
require.Less(t, duration, 20*time.Second, "took too long: %v", duration)
}
// TestGlobalQueueMaxConcurrency verifies that a global queue with maxConcurrency > 1
// processes multiple DAGs concurrently, and that the DAG's maxActiveRuns doesn't
// override the global queue's maxConcurrency setting.
func TestGlobalQueueMaxConcurrency(t *testing.T) {
t.Parallel()
// Configure a global queue with MaxActiveRuns = 3 (its max concurrency) BEFORE
// setup so it gets written to the config file
th := test.SetupCommand(t, test.WithConfigMutator(func(cfg *config.Config) {
cfg.Queues.Enabled = true
cfg.Queues.Config = []config.QueueConfig{
{Name: "global-queue", MaxActiveRuns: 3},
}
}))
// Create a DAG with maxActiveRuns = 1 that uses the global queue.
// Each run sleeps for 1 second so concurrent vs sequential execution is detectable.
dagContent := `name: concurrent-test
queue: global-queue
maxActiveRuns: 1
steps:
  - name: sleep-step
    command: sleep 1
`
require.NoError(t, os.MkdirAll(th.Config.Paths.DAGsDir, 0755))
dagFile := filepath.Join(th.Config.Paths.DAGsDir, "concurrent-test.yaml")
require.NoError(t, os.WriteFile(dagFile, []byte(dagContent), 0644))
dag, err := spec.Load(th.Context, dagFile)
require.NoError(t, err)
// Enqueue 3 items
runIDs := make([]string, 3)
for i := 0; i < 3; i++ {
dagRunID := uuid.New().String()
runIDs[i] = dagRunID
att, err := th.DAGRunStore.CreateAttempt(th.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
require.NoError(t, err)
logFile := filepath.Join(th.Config.Paths.LogDir, dag.Name, dagRunID+".log")
require.NoError(t, os.MkdirAll(filepath.Dir(logFile), 0755))
dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{},
transform.WithLogFilePath(logFile),
transform.WithAttemptID(att.ID()),
transform.WithHierarchyRefs(
execution.NewDAGRunRef(dag.Name, dagRunID),
execution.DAGRunRef{},
),
)
require.NoError(t, att.Open(th.Context))
require.NoError(t, att.Write(th.Context, dagStatus))
require.NoError(t, att.Close(th.Context))
err = th.QueueStore.Enqueue(th.Context, "global-queue", execution.QueuePriorityLow, execution.NewDAGRunRef(dag.Name, dagRunID))
require.NoError(t, err)
}
// Verify queue has 3 items
queuedItems, err := th.QueueStore.List(th.Context, "global-queue")
require.NoError(t, err)
require.Len(t, queuedItems, 3)
// Start scheduler
schedulerDone := make(chan error, 1)
daguHome := filepath.Dir(th.Config.Paths.DAGsDir)
go func() {
ctx, cancel := context.WithTimeout(th.Context, 30*time.Second)
defer cancel()
thCopy := th
thCopy.Context = ctx
schedulerDone <- thCopy.RunCommandWithError(t, cmd.Scheduler(), test.CmdTest{
Args: []string{
"scheduler",
"--dagu-home", daguHome,
},
ExpectedOut: []string{"Scheduler started"},
})
}()
// Wait until all DAGs complete
require.Eventually(t, func() bool {
remaining, err := th.QueueStore.List(th.Context, "global-queue")
if err != nil {
return false
}
t.Logf("Queue: %d/3 items remaining", len(remaining))
return len(remaining) == 0
}, 15*time.Second, 200*time.Millisecond, "Queue items should be processed")
th.Cancel()
select {
case <-schedulerDone:
case <-time.After(5 * time.Second):
}
// Collect start times from all DAG runs
var startTimes []time.Time
for _, runID := range runIDs {
attempt, err := th.DAGRunStore.FindAttempt(th.Context, execution.NewDAGRunRef(dag.Name, runID))
require.NoError(t, err)
status, err := attempt.ReadStatus(th.Context)
require.NoError(t, err)
startedAt, err := stringutil.ParseTime(status.StartedAt)
require.NoError(t, err, "Failed to parse start time for run %s", runID)
require.False(t, startedAt.IsZero(), "Start time is zero for run %s", runID)
startTimes = append(startTimes, startedAt)
}
// All 3 DAGs should have started
require.Len(t, startTimes, 3, "All 3 DAGs should have started")
// Find the max difference between start times
// If they ran concurrently, all should start within ~500ms
// If they ran sequentially (maxConcurrency=1), they'd be ~1s apart
var maxDiff time.Duration
for i := 0; i < len(startTimes); i++ {
for j := i + 1; j < len(startTimes); j++ {
diff := startTimes[i].Sub(startTimes[j]).Abs()
if diff > maxDiff {
maxDiff = diff
}
}
}
t.Logf("Start times: %v", startTimes)
t.Logf("Max difference between start times: %v", maxDiff)
// All DAGs should start within 2 seconds of each other (concurrent execution).
// If maxConcurrency were incorrectly 1, starts would be ~1s apart pairwise
// (1-second sleep per run plus processing overhead), so the first and last
// would start 2s+ apart.
require.Less(t, maxDiff, 2*time.Second,
"All 3 DAGs should start concurrently (within 2s), but max diff was %v", maxDiff)
}
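// The mutator above carries the point of this test: concurrency comes from
// the global queue's MaxActiveRuns, not the DAG's maxActiveRuns. A minimal
// sketch of the same configuration outside the test harness (field names as
// used above):
//
//	cfg.Queues.Enabled = true
//	cfg.Queues.Config = []config.QueueConfig{
//	    {Name: "global-queue", MaxActiveRuns: 3},
//	}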
// TestDAGQueueMaxActiveRunsFirstBatch verifies that when a DAG-based (non-global)
// queue is first encountered, all items up to maxActiveRuns are processed in the
// first batch, not just 1.
//
// This test covers the bug where dynamically created queues were initialized with
// maxConcurrency=1, causing only 1 DAG to start initially even when maxActiveRuns > 1.
// The fix reads the DAG's maxActiveRuns before selecting items for the first batch.
func TestDAGQueueMaxActiveRunsFirstBatch(t *testing.T) {
t.Parallel()
th := test.SetupCommand(t, test.WithConfigMutator(func(cfg *config.Config) {
cfg.Queues.Enabled = true
// No global queues configured - we want to test DAG-based queues
}))
// Create a DAG with maxActiveRuns = 3 (no queue: field, so it uses a DAG-based queue).
// Each run sleeps for 2 seconds so concurrent vs sequential execution is detectable.
dagContent := `name: dag-queue-test
maxActiveRuns: 3
steps:
  - name: sleep-step
    command: sleep 2
`
require.NoError(t, os.MkdirAll(th.Config.Paths.DAGsDir, 0755))
dagFile := filepath.Join(th.Config.Paths.DAGsDir, "dag-queue-test.yaml")
require.NoError(t, os.WriteFile(dagFile, []byte(dagContent), 0644))
dag, err := spec.Load(th.Context, dagFile)
require.NoError(t, err)
// Verify queue name is the DAG name (DAG-based queue)
queueName := dag.ProcGroup()
require.Equal(t, "dag-queue-test", queueName, "DAG should use its name as queue")
// Enqueue 3 items
runIDs := make([]string, 3)
for i := 0; i < 3; i++ {
dagRunID := uuid.New().String()
runIDs[i] = dagRunID
att, err := th.DAGRunStore.CreateAttempt(th.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
require.NoError(t, err)
logFile := filepath.Join(th.Config.Paths.LogDir, dag.Name, dagRunID+".log")
require.NoError(t, os.MkdirAll(filepath.Dir(logFile), 0755))
dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{},
transform.WithLogFilePath(logFile),
transform.WithAttemptID(att.ID()),
transform.WithHierarchyRefs(
execution.NewDAGRunRef(dag.Name, dagRunID),
execution.DAGRunRef{},
),
)
require.NoError(t, att.Open(th.Context))
require.NoError(t, att.Write(th.Context, dagStatus))
require.NoError(t, att.Close(th.Context))
err = th.QueueStore.Enqueue(th.Context, queueName, execution.QueuePriorityLow, execution.NewDAGRunRef(dag.Name, dagRunID))
require.NoError(t, err)
}
// Verify queue has 3 items
queuedItems, err := th.QueueStore.List(th.Context, queueName)
require.NoError(t, err)
require.Len(t, queuedItems, 3)
t.Logf("Enqueued 3 items to DAG-based queue %q", queueName)
// Start scheduler
schedulerDone := make(chan error, 1)
daguHome := filepath.Dir(th.Config.Paths.DAGsDir)
go func() {
ctx, cancel := context.WithTimeout(th.Context, 30*time.Second)
defer cancel()
thCopy := th
thCopy.Context = ctx
schedulerDone <- thCopy.RunCommandWithError(t, cmd.Scheduler(), test.CmdTest{
Args: []string{
"scheduler",
"--dagu-home", daguHome,
},
ExpectedOut: []string{"Scheduler started"},
})
}()
// Wait until all DAGs complete
startTime := time.Now()
require.Eventually(t, func() bool {
remaining, err := th.QueueStore.List(th.Context, queueName)
if err != nil {
return false
}
t.Logf("Queue: %d/3 items remaining", len(remaining))
return len(remaining) == 0
}, 20*time.Second, 200*time.Millisecond, "Queue items should be processed")
totalDuration := time.Since(startTime)
t.Logf("All items processed in %v", totalDuration)
th.Cancel()
select {
case <-schedulerDone:
case <-time.After(5 * time.Second):
}
// Collect start times from all DAG runs
var startTimes []time.Time
for _, runID := range runIDs {
attempt, err := th.DAGRunStore.FindAttempt(th.Context, execution.NewDAGRunRef(dag.Name, runID))
require.NoError(t, err)
status, err := attempt.ReadStatus(th.Context)
require.NoError(t, err)
startedAt, err := stringutil.ParseTime(status.StartedAt)
require.NoError(t, err, "Failed to parse start time for run %s", runID)
require.False(t, startedAt.IsZero(), "Start time is zero for run %s", runID)
startTimes = append(startTimes, startedAt)
}
// All 3 DAGs should have started
require.Len(t, startTimes, 3, "All 3 DAGs should have started")
// Find the max difference between start times
var maxDiff time.Duration
for i := 0; i < len(startTimes); i++ {
for j := i + 1; j < len(startTimes); j++ {
diff := startTimes[i].Sub(startTimes[j]).Abs()
if diff > maxDiff {
maxDiff = diff
}
}
}
t.Logf("Start times: %v", startTimes)
t.Logf("Max difference between start times: %v", maxDiff)
// KEY ASSERTION: All 3 DAGs should start in the FIRST batch (concurrently)
// If the bug exists (queue initialized with maxConcurrency=1), they would start
// sequentially: first at 0s, second at ~2s, third at ~4s (total ~6s+)
// With the fix, all 3 start within the first batch, so max diff <= 2s
require.LessOrEqual(t, maxDiff, 2*time.Second,
"All 3 DAGs should start in first batch (within 2s), but max diff was %v. "+
"This suggests maxActiveRuns was not applied to the first batch.", maxDiff)
// Also verify total time is reasonable
// - Concurrent execution: ~2s sleep + scheduler overhead (~4-6s total)
// - Sequential execution (bug): ~6s sleep + overhead (~8-10s total)
require.Less(t, totalDuration, 8*time.Second,
"Total processing time should be under 8s for concurrent execution, but was %v", totalDuration)
}
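// Both tests above compute the same pairwise spread by hand; as a sketch, the
// check factors out to a helper like this (name hypothetical). With n runs and
// a per-run sleep of d, concurrent starts keep the spread well under d, while
// serial starts push the first-to-last gap to roughly (n-1)*d.
func maxStartSpreadSketch(startTimes []time.Time) time.Duration {
    var maxDiff time.Duration
    for i := 0; i < len(startTimes); i++ {
        for j := i + 1; j < len(startTimes); j++ {
            if d := startTimes[i].Sub(startTimes[j]).Abs(); d > maxDiff {
                maxDiff = d
            }
        }
    }
    return maxDiff
}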
// TestCronScheduleRunsTwice verifies that a DAG with */1 * * * * schedule
// runs twice in two minutes.
func TestCronScheduleRunsTwice(t *testing.T) {
