Mirror of https://github.com/dagu-org/dagu.git (synced 2025-12-28 06:34:22 +00:00)
Compare commits
29 Commits
| SHA1 |
|---|
| 4633850d2d |
| 140afbaffe |
| 345ef13565 |
| df257b49e8 |
| 03e692d6d8 |
| 3e9d49d1e0 |
| 64f1884e45 |
| 1fb869dda8 |
| 05452f9669 |
| d0193cb8c0 |
| 8945b926b5 |
| d87b8c6dff |
| dcb3cf1570 |
| d2b3d8fdd3 |
| f3d4577e42 |
| 5d90d2744f |
| f04786c5e1 |
| 887355f0ca |
| eef457b4c2 |
| c72623154d |
| 5d6e50df04 |
| cb83c59a6d |
| be3e71b79a |
| 9841e6ed70 |
| 3ebfa3cbf2 |
| 23a29f336a |
| 5e7ce83afa |
| 0498640f11 |
| a75f667f95 |
.vscode/launch.json (vendored) · 2 changes
@@ -42,7 +42,7 @@
      "request": "launch",
      "mode": "auto",
      "program": "${workspaceFolder}/cmd/",
-     "args": ["start", "--no-queue", "${input:pathToSpec}"]
+     "args": ["start", "${input:pathToSpec}"]
    },
    {
      "name": "Retry",
README.md · 122 changes
@@ -31,23 +31,17 @@ Built for developers who want powerful workflow orchestration without the operat

## Why Dagu?

### 🚀 Zero Dependencies
**Single binary. No database, no message broker.** Deploy anywhere in seconds—from your laptop to bare metal servers to Kubernetes. Everything is stored in plain files (XDG compliant), making it transparent, portable, and easy to back up.
Many workflow orchestrators already exist, and Apache Airflow is a well-known example. In Airflow, DAGs are loaded from Python source files, so defining workflows typically means writing and maintaining Python code. In real deployments, Airflow commonly involves multiple running components (for example, scheduler, webserver, metadata database, and workers), and DAG files often need to be synchronized across them, which can increase operational complexity.

### 🧩 Composable Nested Workflows
**Build complex pipelines from reusable building blocks.** Define sub-workflows that can be called with parameters, executed in parallel, and fully monitored in the UI. See execution traces for every nested level—no black boxes.
Dagu is a self-contained workflow engine where workflows are defined in simple YAML and executed with a single binary. It is designed to run without requiring external databases or message brokers, using local files for definitions, logs, and metadata. Because it orchestrates commands rather than forcing you into a specific programming model, it is easy to integrate existing shell scripts and operational commands as they are. Our goal is to make Dagu an ideal workflow engine for small teams that want orchestration power with minimal setup and operational overhead.

### 🌐 Language Agnostic
**Use your existing scripts without modification.** No need to wrap everything in Python decorators or rewrite logic. Dagu orchestrates shell commands, Docker containers, SSH commands, or HTTP calls—whatever you already have.

### ⚡ Distributed Execution
**Built-in queue system with intelligent task routing.** Route tasks to workers based on labels (GPU, region, compliance requirements). Automatic service registry and health monitoring included—no external coordination service needed.

### 🎯 Production Ready
**Not a toy.** Battle-tested error handling with exponential backoff retries, lifecycle hooks (onSuccess, onFailure, onExit), real-time log streaming, email notifications, Prometheus metrics, and OpenTelemetry tracing out of the box. Built-in user management with role-based access control (RBAC) for team environments.

### 🎨 Modern Web UI
**Beautiful UI that actually helps you debug.** Live log tailing, DAG visualization with Gantt charts, execution history with full lineage, and drill-down into nested sub-workflows. Dark mode included.

## Highlights

- Single binary file installation
- Declarative YAML format for defining DAGs
- Web UI for visually managing, rerunning, and monitoring pipelines
- Use existing programs without any modification
- Self-contained, with no need for a DBMS

## Quick Start
@@ -154,6 +148,62 @@ dagu start-all

Visit http://localhost:8080

## Quick Look for Workflow Definitions

### Sequential Steps

Steps execute one after another:

```yaml
type: chain
steps:
  - command: echo "Hello, dagu!"
  - command: echo "This is a second step"
```

```mermaid
%%{init: {'theme': 'base', 'themeVariables': {'background': '#3A322C', 'primaryTextColor': '#fff', 'lineColor': '#888'}}}%%
graph LR
    A["Step 1"] --> B["Step 2"]
    style A fill:#3A322C,stroke:green,stroke-width:1.6px,color:#fff
    style B fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
```

### Parallel Steps

Steps whose dependencies are satisfied run in parallel:

```yaml
type: graph
steps:
  - id: step_1
    command: echo "Step 1"
  - id: step_2a
    command: echo "Step 2a - runs in parallel"
    depends: [step_1]
  - id: step_2b
    command: echo "Step 2b - runs in parallel"
    depends: [step_1]
  - id: step_3
    command: echo "Step 3 - waits for parallel steps"
    depends: [step_2a, step_2b]
```

```mermaid
%%{init: {'theme': 'base', 'themeVariables': {'background': '#3A322C', 'primaryTextColor': '#fff', 'lineColor': '#888'}}}%%
graph LR
    A[step_1] --> B[step_2a]
    A --> C[step_2b]
    B --> D[step_3]
    C --> D
    style A fill:#3A322C,stroke:green,stroke-width:1.6px,color:#fff
    style B fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
    style C fill:#3A322C,stroke:lime,stroke-width:1.6px,color:#fff
    style D fill:#3A322C,stroke:lightblue,stroke-width:1.6px,color:#fff
```

For more examples, see the [Examples](https://docs.dagu.cloud/writing-workflows/examples) documentation.

## Docker-Compose

Clone the repository and run with Docker Compose:
@@ -496,6 +546,37 @@ For discussions, support, and sharing ideas, join our community on [Discord](htt

Changelog of recent updates can be found in the [Changelog](https://docs.dagu.cloud/reference/changelog) section of the documentation.

## Acknowledgements

### Sponsors & Supporters

<div align="center">
  <h3>💜 Premium Sponsors</h3>
  <a href="https://github.com/slashbinlabs">
    <img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fslashbinlabs.png&w=150&h=150&fit=cover&mask=circle" width="100" height="100" alt="@slashbinlabs">
  </a>

  <h3>✨ Supporters</h3>
  <a href="https://github.com/disizmj">
    <img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fdisizmj.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@disizmj" style="margin: 5px;">
  </a>
  <a href="https://github.com/Arvintian">
    <img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2FArvintian.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@Arvintian" style="margin: 5px;">
  </a>
  <a href="https://github.com/yurivish">
    <img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fyurivish.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@yurivish" style="margin: 5px;">
  </a>
  <a href="https://github.com/jayjoshi64">
    <img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fjayjoshi64.png&w=128&h=128&fit=cover&mask=circle" width="50" height="50" alt="@jayjoshi64" style="margin: 5px;">
  </a>

  <br/><br/>

  <a href="https://github.com/sponsors/dagu-org">
    <img src="https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86" width="150" alt="Sponsor">
  </a>
</div>

## Contributing

We welcome contributions of all kinds! Whether you're a developer, a designer, or a user, your help is valued. Here are a few ways to get involved:
@@ -507,8 +588,6 @@ We welcome contributions of all kinds! Whether you're a developer, a designer, o

For more details, see our [Contribution Guide](./CONTRIBUTING.md).

## Acknowledgements

### Contributors

<a href="https://github.com/dagu-org/dagu/graphs/contributors">
@@ -517,15 +596,6 @@ For more details, see our [Contribution Guide](./CONTRIBUTING.md).

Thanks to all the contributors who have helped make Dagu better! Your contributions, whether through code, documentation, or feedback, are invaluable to the project.

-### Sponsors & Supporters
-
-<a href="https://github.com/disizmj"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fdisizmj.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@disizmj"></a>
-<a href="https://github.com/Arvintian"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2FArvintian.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@Arvintian"></a>
-<a href="https://github.com/yurivish"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fyurivish.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@yurivish"></a>
-<a href="https://github.com/jayjoshi64"><img src="https://wsrv.nl/?url=https%3A%2F%2Fgithub.com%2Fjayjoshi64.png&w=128&h=128&fit=cover&mask=circle" width="64" height="64" alt="@jayjoshi64"></a>
-
-Thanks for supporting Dagu’s development! Join our supporters: [GitHub Sponsors](https://github.com/sponsors/dagu-org)

## License

GNU GPLv3 - See [LICENSE](./LICENSE)
api/v2/api.gen.go · 1409 changes (file diff suppressed because it is too large)
api/v2/api.yaml · 496 changes
@@ -37,6 +37,8 @@ tags:
    description: "Authentication operations (login, logout, token management)"
  - name: "users"
    description: "User management operations (CRUD, password management)"
+ - name: "api-keys"
+   description: "API key management operations (admin only)"

paths:
  /health:
@@ -56,10 +58,6 @@
        default:
          description: "Unexpected error"

- # ============================================================================
- # Authentication Endpoints
- # ============================================================================
-
  /auth/login:
    post:
      summary: "Authenticate user and obtain JWT token"
@@ -435,6 +433,225 @@
              schema:
                $ref: "#/components/schemas/Error"
  # API Key Management (Admin only)
  /api-keys:
    get:
      summary: "List all API keys"
      description: "Returns all API keys. Requires admin role."
      operationId: "listAPIKeys"
      tags:
        - "api-keys"
      responses:
        "200":
          description: "List of API keys"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/APIKeysListResponse"
        "401":
          description: "Not authenticated"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "403":
          description: "Requires admin role"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Error"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"

    post:
      summary: "Create API key"
      description: "Full key returned only in this response"
      operationId: "createAPIKey"
      tags:
        - "api-keys"
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/CreateAPIKeyRequest"
      responses:
        "201":
          description: "Created"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/CreateAPIKeyResponse"
        "400":
          description: "Invalid request"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "401":
          description: "Not authenticated"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "403":
          description: "Requires admin role"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "409":
          description: "Name already exists"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Error"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"

  /api-keys/{keyId}:
    get:
      summary: "Get API key"
      description: "Returns API key by ID. Requires admin role."
      operationId: "getAPIKey"
      tags:
        - "api-keys"
      parameters:
        - $ref: "#/components/parameters/APIKeyId"
      responses:
        "200":
          description: "API key details"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/APIKeyResponse"
        "401":
          description: "Not authenticated"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "403":
          description: "Requires admin role"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "404":
          description: "Not found"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Error"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"

    patch:
      summary: "Update API key"
      description: "Updates API key info. Requires admin role."
      operationId: "updateAPIKey"
      tags:
        - "api-keys"
      parameters:
        - $ref: "#/components/parameters/APIKeyId"
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/UpdateAPIKeyRequest"
      responses:
        "200":
          description: "Updated API key"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/APIKeyResponse"
        "400":
          description: "Invalid request"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "401":
          description: "Not authenticated"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "403":
          description: "Requires admin role"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "404":
          description: "Not found"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "409":
          description: "Name already exists"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Error"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"

    delete:
      summary: "Delete API key"
      description: "Revokes an API key. Requires admin role."
      operationId: "deleteAPIKey"
      tags:
        - "api-keys"
      parameters:
        - $ref: "#/components/parameters/APIKeyId"
      responses:
        "204":
          description: "API key deleted"
        "401":
          description: "Not authenticated"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "403":
          description: "Requires admin role"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        "404":
          description: "Not found"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Error"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
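For orientation, here is a minimal client-side sketch of the create-key flow specified above. The request and response shapes follow the CreateAPIKeyRequest/CreateAPIKeyResponse schemas; the base URL, the /api/v2 path prefix, and the bearer-token value are assumptions, and the generated client in the suppressed api.gen.go is the real interface.

// Hypothetical client sketch for the "createAPIKey" operation (not from the diff).
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// name and role are the required CreateAPIKeyRequest fields.
	body, _ := json.Marshal(map[string]string{
		"name": "ci-deploy",
		"role": "operator",
	})
	req, _ := http.NewRequest(http.MethodPost,
		"http://localhost:8080/api/v2/api-keys", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <admin-jwt>") // admin role required

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Per the spec, the 201 response carries the full key exactly once; store it now.
	var out struct {
		APIKey struct {
			ID        string `json:"id"`
			KeyPrefix string `json:"keyPrefix"`
		} `json:"apiKey"`
		Key string `json:"key"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("created", out.APIKey.ID, "prefix", out.APIKey.KeyPrefix, "secret:", out.Key)
}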
  /workers:
    get:
      summary: "List distributed workers"
@@ -797,6 +1014,10 @@ paths:
                queue:
                  type: string
                  description: "Override the DAG-level queue definition"
+               singleton:
+                 type: boolean
+                 description: "If true, prevent enqueuing if DAG is already running or queued (returns 409 conflict)"
+                 default: false
      responses:
        "200":
          description: "A successful response"
@@ -810,6 +1031,12 @@
                    description: "ID of the created DAG-run"
                  required:
                    - dagRunId
+       "409":
+         description: "DAG is already running or queued (singleton mode)"
+         content:
+           application/json:
+             schema:
+               $ref: "#/components/schemas/Error"
        default:
          description: "Generic error response"
          content:
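As a reading aid, the new request-body fields map onto a small struct like this (a hypothetical mirror; the actual generated type lives in the suppressed api.gen.go, and the enqueue endpoint's path is outside this hunk):

// enqueueOptions mirrors the fields added in this hunk (names from the spec).
type enqueueOptions struct {
	// Queue overrides the DAG-level queue definition.
	Queue string `json:"queue,omitempty"`
	// Singleton asks the server to reject the request with 409 Conflict
	// when the DAG is already running or queued; defaults to false.
	Singleton bool `json:"singleton,omitempty"`
}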
@@ -1472,7 +1699,7 @@
  /dag-runs/{name}/{dagRunId}/sub-dag-runs:
    get:
      summary: "Get sub DAG runs with timing info"
-     description: "Retrieves timing and status information for all sub DAG runs (including repeated executions) of a specific step"
+     description: "Retrieves timing and status information for all sub DAG runs (including repeated executions) of a specific step. When parentSubDAGRunId is provided, returns sub-runs of that specific sub DAG run (for multi-level nested DAGs)."
      operationId: "getSubDAGRuns"
      tags:
        - "dag-runs"
@@ -1480,6 +1707,12 @@
        - $ref: "#/components/parameters/RemoteNode"
        - $ref: "#/components/parameters/DAGName"
        - $ref: "#/components/parameters/DAGRunId"
+       - name: "parentSubDAGRunId"
+         in: "query"
+         required: false
+         schema:
+           type: "string"
+         description: "Optional parent sub DAG run ID. When provided, returns sub-runs of this specific sub DAG run instead of the root DAG run. Used for multi-level nested DAGs."
      responses:
        "200":
          description: "A successful response"
@@ -1569,6 +1802,37 @@
              schema:
                $ref: "#/components/schemas/Error"

  /dag-runs/{name}/{dagRunId}/outputs:
    get:
      summary: "Retrieve collected outputs from a DAG-run"
      description: "Fetches the outputs.json file containing all step outputs collected during the DAG-run execution. Returns the outputs as a JSON object where keys are the output names (converted from UPPER_CASE to camelCase by default, or custom key if specified) and values are the captured output strings."
      operationId: "getDAGRunOutputs"
      tags:
        - "dag-runs"
      parameters:
        - $ref: "#/components/parameters/RemoteNode"
        - $ref: "#/components/parameters/DAGName"
        - $ref: "#/components/parameters/DAGRunId"
      responses:
        "200":
          description: "Successfully retrieved outputs. Returns the collected outputs with metadata. If the DAG-run completed but captured no outputs, returns an empty outputs object."
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/DAGRunOutputs"
        "404":
          description: "DAG-run not found"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
        default:
          description: "Generic error response"
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/Error"
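A minimal sketch of consuming the new outputs endpoint, assuming a local server and an /api/v2 prefix (both assumptions; only the path template and the DAGRunOutputs shape come from the spec):

// fetchOutputs is a hypothetical helper; imports: fmt, net/http, encoding/json.
func fetchOutputs(baseURL, name, runID string) (map[string]string, error) {
	url := fmt.Sprintf("%s/dag-runs/%s/%s/outputs", baseURL, name, runID)
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusNotFound {
		return nil, fmt.Errorf("DAG-run %s/%s not found", name, runID)
	}
	// DAGRunOutputs: {"metadata": {...}, "outputs": {"totalCount": "42", ...}}
	var body struct {
		Outputs map[string]string `json:"outputs"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		return nil, err
	}
	return body.Outputs, nil
}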
  /dag-runs/{name}/{dagRunId}/retry:
    post:
      summary: "Retry DAG-run execution"
@@ -2065,6 +2329,15 @@ components:
        type: string
        minLength: 1

+   APIKeyId:
+     name: keyId
+     in: path
+     description: unique identifier of the API key
+     required: true
+     schema:
+       type: string
+       minLength: 1

    PerPage:
      name: perPage
      in: query
@@ -2233,6 +2506,9 @@
        - "max_run_reached"
        - "not_running"
        - "already_exists"
+       - "auth.unauthorized"
+       - "auth.token_invalid"
+       - "auth.forbidden"

    Stream:
      type: string
@@ -2737,6 +3013,52 @@
      required:
        - nodes
    DAGRunOutputs:
      type: object
      description: "Collected outputs from step executions in a DAG-run, including execution metadata. If the DAG-run completed but no outputs were captured, the outputs object will be empty and metadata fields may be empty strings."
      required:
        - metadata
        - outputs
      properties:
        metadata:
          $ref: "#/components/schemas/OutputsMetadata"
        outputs:
          type: object
          description: "Collected step outputs as key-value pairs. Keys are output names (UPPER_CASE converted to camelCase by default, or custom key if specified) and values are the captured output strings. Empty object if no outputs were captured."
          additionalProperties:
            type: string
          example:
            totalCount: "42"
            resultFile: "/path/to/result.txt"
            config: '{"key": "value"}'

    OutputsMetadata:
      type: object
      description: "Execution context metadata for the outputs"
      required:
        - dagName
        - dagRunId
        - attemptId
        - status
        - completedAt
      properties:
        dagName:
          $ref: "#/components/schemas/DAGName"
        dagRunId:
          $ref: "#/components/schemas/DAGRunId"
        attemptId:
          type: string
          description: "Attempt identifier within the run"
        status:
          $ref: "#/components/schemas/StatusLabel"
        completedAt:
          type: string
          format: date-time
          description: "RFC3339 timestamp when execution completed"
        params:
          type: string
          description: "JSON-serialized parameters passed to the DAG"

    Node:
      type: object
      description: "Status of an individual step within a DAG-run"
@@ -2806,8 +3128,7 @@
      description: "Detailed information for a sub DAG-run including timing and status"
      properties:
        dagRunId:
-         type: string
-         description: "Unique identifier for the sub DAG-run"
+         $ref: "#/components/schemas/DAGRunId"
        params:
          type: string
          description: "Parameters passed to the sub DAG-run in JSON format"
@@ -2843,12 +3164,11 @@
        dir:
          type: string
          description: "Working directory for executing the step's command"
-       cmdWithArgs:
-         type: string
-         description: "Complete command string including arguments to execute"
        command:
          type: string
          description: "Base command to execute without arguments"
+       commands:
+         type: array
+         description: "List of commands to execute sequentially"
+         items:
+           $ref: "#/components/schemas/CommandEntry"
        script:
          type: string
          description: "Script content if the step executes a script file"
@@ -2861,11 +3181,6 @@
        output:
          type: string
          description: "Variable name to store the step's output"
-       args:
-         type: array
-         description: "List of arguments to pass to the command"
-         items:
-           type: string
        call:
          type: string
          description: "The name of the DAG to execute as a sub DAG-run"
@@ -3046,6 +3361,21 @@
          items:
            type: integer

    CommandEntry:
      type: object
      description: "A command with its arguments"
      properties:
        command:
          type: string
          description: "The command to execute"
        args:
          type: array
          description: "Arguments for the command"
          items:
            type: string
      required:
        - command

    ListTagResponse:
      type: object
      description: "Response object for listing all tags"
@@ -3124,26 +3454,20 @@
      description: "Information about a task currently being executed"
      properties:
        dagRunId:
-         type: string
-         description: "ID of the DAG run being executed"
+         $ref: "#/components/schemas/DAGRunId"
        dagName:
-         type: string
-         description: "Name of the DAG being executed"
+         $ref: "#/components/schemas/DAGName"
        startedAt:
          type: string
          description: "RFC3339 timestamp when the task started"
        rootDagRunName:
-         type: string
-         description: "Name of the root DAG run"
+         $ref: "#/components/schemas/DAGName"
        rootDagRunId:
-         type: string
-         description: "ID of the root DAG run"
+         $ref: "#/components/schemas/DAGRunId"
        parentDagRunName:
-         type: string
-         description: "Name of the parent DAG run"
+         $ref: "#/components/schemas/DAGName"
        parentDagRunId:
-         type: string
-         description: "ID of the parent DAG run"
+         $ref: "#/components/schemas/DAGRunId"
      required:
        - dagRunId
        - dagName
@@ -3415,6 +3739,118 @@
      required:
        - users

    # ============================================================================
    APIKey:
      type: object
      description: "API key information"
      properties:
        id:
          type: string
          description: "Unique identifier"
        name:
          type: string
          description: "Human-readable name"
        description:
          type: string
          description: "Purpose description"
        role:
          $ref: "#/components/schemas/UserRole"
        keyPrefix:
          type: string
          description: "First 8 characters for identification"
        createdAt:
          type: string
          format: date-time
          description: "Creation timestamp"
        updatedAt:
          type: string
          format: date-time
          description: "Last update timestamp"
        createdBy:
          type: string
          description: "Creator user ID"
        lastUsedAt:
          type: string
          format: date-time
          nullable: true
          description: "Last authentication timestamp"
      required:
        - id
        - name
        - role
        - keyPrefix
        - createdAt
        - updatedAt
        - createdBy

    APIKeyResponse:
      type: object
      description: "API key response"
      properties:
        apiKey:
          $ref: "#/components/schemas/APIKey"
      required:
        - apiKey

    APIKeysListResponse:
      type: object
      description: "List of API keys"
      properties:
        apiKeys:
          type: array
          items:
            $ref: "#/components/schemas/APIKey"
      required:
        - apiKeys

    CreateAPIKeyRequest:
      type: object
      description: "Create API key request"
      properties:
        name:
          type: string
          minLength: 1
          maxLength: 100
          description: "Human-readable name"
        description:
          type: string
          maxLength: 500
          description: "Purpose description"
        role:
          $ref: "#/components/schemas/UserRole"
      required:
        - name
        - role

    CreateAPIKeyResponse:
      type: object
      description: "Create API key response"
      properties:
        apiKey:
          $ref: "#/components/schemas/APIKey"
        key:
          type: string
          description: "Full key secret, only returned once"
      required:
        - apiKey
        - key

    UpdateAPIKeyRequest:
      type: object
      description: "Update API key request"
      properties:
        name:
          type: string
          minLength: 1
          maxLength: 100
          description: "New name"
        description:
          type: string
          maxLength: 500
          description: "New description"
        role:
          $ref: "#/components/schemas/UserRole"

    SuccessResponse:
      type: object
      description: "Generic success response"
Binary file not shown. (Before: 95 KiB, After: 60 KiB)
@@ -45,6 +45,7 @@ func init() {
	rootCmd.AddCommand(cmd.Retry())
	rootCmd.AddCommand(cmd.StartAll())
	rootCmd.AddCommand(cmd.Migrate())
+	rootCmd.AddCommand(cmd.Cleanup())

	config.Version = version
}
internal/auth/apikey.go · new file, 115 lines
@@ -0,0 +1,115 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package auth

import (
	"time"

	"github.com/google/uuid"
)

// APIKey represents a standalone API key in the system.
// API keys are independent entities with their own role assignment,
// enabling programmatic access with fine-grained permissions.
type APIKey struct {
	// ID is the unique identifier for the API key (UUID).
	ID string `json:"id"`
	// Name is a human-readable name for the API key (required).
	Name string `json:"name"`
	// Description is an optional description of the API key's purpose.
	Description string `json:"description,omitempty"`
	// Role determines the API key's permissions.
	Role Role `json:"role"`
	// KeyHash is the bcrypt hash of the API key secret.
	// Excluded from JSON serialization for security.
	KeyHash string `json:"-"`
	// KeyPrefix stores the first 8 characters of the key for identification.
	KeyPrefix string `json:"key_prefix"`
	// CreatedAt is the timestamp when the API key was created.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is the timestamp when the API key was last modified.
	UpdatedAt time.Time `json:"updated_at"`
	// CreatedBy is the user ID of the admin who created the API key.
	CreatedBy string `json:"created_by"`
	// LastUsedAt is the timestamp when the API key was last used for authentication.
	LastUsedAt *time.Time `json:"last_used_at,omitempty"`
}

// NewAPIKey creates an APIKey with a new UUID and sets CreatedAt and UpdatedAt to the current UTC time.
// It validates that required fields are not empty and the role is valid.
// Returns an error if validation fails.
func NewAPIKey(name, description string, role Role, keyHash, keyPrefix, createdBy string) (*APIKey, error) {
	if name == "" {
		return nil, ErrInvalidAPIKeyName
	}
	if keyHash == "" {
		return nil, ErrInvalidAPIKeyHash
	}
	if !role.Valid() {
		return nil, ErrInvalidRole
	}
	now := time.Now().UTC()
	return &APIKey{
		ID:          uuid.New().String(),
		Name:        name,
		Description: description,
		Role:        role,
		KeyHash:     keyHash,
		KeyPrefix:   keyPrefix,
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   createdBy,
	}, nil
}

// APIKeyForStorage is used for JSON serialization to persistent storage.
// It includes the key hash which is excluded from the regular APIKey JSON.
type APIKeyForStorage struct {
	ID          string     `json:"id"`
	Name        string     `json:"name"`
	Description string     `json:"description,omitempty"`
	Role        Role       `json:"role"`
	KeyHash     string     `json:"key_hash"`
	KeyPrefix   string     `json:"key_prefix"`
	CreatedAt   time.Time  `json:"created_at"`
	UpdatedAt   time.Time  `json:"updated_at"`
	CreatedBy   string     `json:"created_by"`
	LastUsedAt  *time.Time `json:"last_used_at,omitempty"`
}

// ToStorage converts an APIKey to APIKeyForStorage for persistence.
// NOTE: When adding new fields to APIKey or APIKeyForStorage, ensure both
// ToStorage and ToAPIKey are updated to maintain field synchronization.
func (k *APIKey) ToStorage() *APIKeyForStorage {
	return &APIKeyForStorage{
		ID:          k.ID,
		Name:        k.Name,
		Description: k.Description,
		Role:        k.Role,
		KeyHash:     k.KeyHash,
		KeyPrefix:   k.KeyPrefix,
		CreatedAt:   k.CreatedAt,
		UpdatedAt:   k.UpdatedAt,
		CreatedBy:   k.CreatedBy,
		LastUsedAt:  k.LastUsedAt,
	}
}

// ToAPIKey converts APIKeyForStorage back to APIKey.
// NOTE: When adding new fields to APIKey or APIKeyForStorage, ensure both
// ToStorage and ToAPIKey are updated to maintain field synchronization.
func (s *APIKeyForStorage) ToAPIKey() *APIKey {
	return &APIKey{
		ID:          s.ID,
		Name:        s.Name,
		Description: s.Description,
		Role:        s.Role,
		KeyHash:     s.KeyHash,
		KeyPrefix:   s.KeyPrefix,
		CreatedAt:   s.CreatedAt,
		UpdatedAt:   s.UpdatedAt,
		CreatedBy:   s.CreatedBy,
		LastUsedAt:  s.LastUsedAt,
	}
}
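A hypothetical caller-side sketch of how the pieces above fit together: mint a secret, bcrypt-hash it (KeyHash is documented as a bcrypt hash), keep the first 8 characters as the prefix, and build the record. The "dagu_" secret format is an assumption taken from the test fixtures below; imports would be github.com/google/uuid and golang.org/x/crypto/bcrypt.

func createKey(name, description string, role Role, createdBy string) (*APIKey, string, error) {
	// Generate the plaintext secret; it is shown to the caller exactly once.
	secret := "dagu_" + uuid.New().String() // hypothetical format
	hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost)
	if err != nil {
		return nil, "", err
	}
	// The first 8 characters become the identifying prefix (per the KeyPrefix docs).
	key, err := NewAPIKey(name, description, role, string(hash), secret[:8], createdBy)
	if err != nil {
		return nil, "", err
	}
	return key, secret, nil
}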
internal/auth/apikey_test.go · new file, 294 lines
@@ -0,0 +1,294 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package auth

import (
	"encoding/json"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewAPIKey(t *testing.T) {
	before := time.Now().UTC()
	key, err := NewAPIKey("test-key", "Test description", RoleManager, "hash123", "dagu_tes", "creator-id")
	after := time.Now().UTC()

	require.NoError(t, err)
	assert.NotEmpty(t, key.ID, "ID should be generated")
	assert.Equal(t, "test-key", key.Name)
	assert.Equal(t, "Test description", key.Description)
	assert.Equal(t, RoleManager, key.Role)
	assert.Equal(t, "hash123", key.KeyHash)
	assert.Equal(t, "dagu_tes", key.KeyPrefix)
	assert.Equal(t, "creator-id", key.CreatedBy)
	assert.Nil(t, key.LastUsedAt)

	// CreatedAt and UpdatedAt should be set to approximately now
	assert.True(t, !key.CreatedAt.Before(before), "CreatedAt should be >= before")
	assert.True(t, !key.CreatedAt.After(after), "CreatedAt should be <= after")
	assert.Equal(t, key.CreatedAt, key.UpdatedAt, "CreatedAt and UpdatedAt should be equal initially")
}

func TestNewAPIKey_Validation(t *testing.T) {
	tests := []struct {
		name        string
		keyName     string
		description string
		role        Role
		keyHash     string
		keyPrefix   string
		createdBy   string
		wantErr     error
	}{
		{
			name:        "empty name returns error",
			keyName:     "",
			description: "desc",
			role:        RoleViewer,
			keyHash:     "hash",
			keyPrefix:   "prefix",
			createdBy:   "creator",
			wantErr:     ErrInvalidAPIKeyName,
		},
		{
			name:        "empty key hash returns error",
			keyName:     "test-key",
			description: "desc",
			role:        RoleViewer,
			keyHash:     "",
			keyPrefix:   "prefix",
			createdBy:   "creator",
			wantErr:     ErrInvalidAPIKeyHash,
		},
		{
			name:        "invalid role returns error",
			keyName:     "test-key",
			description: "desc",
			role:        Role("invalid"),
			keyHash:     "hash",
			keyPrefix:   "prefix",
			createdBy:   "creator",
			wantErr:     ErrInvalidRole,
		},
		{
			name:        "valid input returns no error",
			keyName:     "test-key",
			description: "desc",
			role:        RoleViewer,
			keyHash:     "hash",
			keyPrefix:   "prefix",
			createdBy:   "creator",
			wantErr:     nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			key, err := NewAPIKey(tt.keyName, tt.description, tt.role, tt.keyHash, tt.keyPrefix, tt.createdBy)
			if tt.wantErr != nil {
				assert.ErrorIs(t, err, tt.wantErr)
				assert.Nil(t, key)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, key)
			}
		})
	}
}

func TestAPIKey_ToStorage(t *testing.T) {
	now := time.Now().UTC()
	lastUsed := now.Add(-time.Hour)
	key := &APIKey{
		ID:          "key-id",
		Name:        "test-key",
		Description: "Test description",
		Role:        RoleAdmin,
		KeyHash:     "hash123",
		KeyPrefix:   "dagu_tes",
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   "creator-id",
		LastUsedAt:  &lastUsed,
	}

	storage := key.ToStorage()

	assert.Equal(t, key.ID, storage.ID)
	assert.Equal(t, key.Name, storage.Name)
	assert.Equal(t, key.Description, storage.Description)
	assert.Equal(t, key.Role, storage.Role)
	assert.Equal(t, key.KeyHash, storage.KeyHash)
	assert.Equal(t, key.KeyPrefix, storage.KeyPrefix)
	assert.Equal(t, key.CreatedAt, storage.CreatedAt)
	assert.Equal(t, key.UpdatedAt, storage.UpdatedAt)
	assert.Equal(t, key.CreatedBy, storage.CreatedBy)
	require.NotNil(t, storage.LastUsedAt)
	assert.Equal(t, *key.LastUsedAt, *storage.LastUsedAt)
}

func TestAPIKeyForStorage_ToAPIKey(t *testing.T) {
	now := time.Now().UTC()
	lastUsed := now.Add(-time.Hour)
	storage := &APIKeyForStorage{
		ID:          "key-id",
		Name:        "test-key",
		Description: "Test description",
		Role:        RoleViewer,
		KeyHash:     "hash456",
		KeyPrefix:   "dagu_xyz",
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   "admin-user",
		LastUsedAt:  &lastUsed,
	}

	key := storage.ToAPIKey()

	assert.Equal(t, storage.ID, key.ID)
	assert.Equal(t, storage.Name, key.Name)
	assert.Equal(t, storage.Description, key.Description)
	assert.Equal(t, storage.Role, key.Role)
	assert.Equal(t, storage.KeyHash, key.KeyHash)
	assert.Equal(t, storage.KeyPrefix, key.KeyPrefix)
	assert.Equal(t, storage.CreatedAt, key.CreatedAt)
	assert.Equal(t, storage.UpdatedAt, key.UpdatedAt)
	assert.Equal(t, storage.CreatedBy, key.CreatedBy)
	require.NotNil(t, key.LastUsedAt)
	assert.Equal(t, *storage.LastUsedAt, *key.LastUsedAt)
}

func TestAPIKey_ToStorage_ToAPIKey_Roundtrip(t *testing.T) {
	now := time.Now().UTC()
	original := &APIKey{
		ID:          "key-id",
		Name:        "roundtrip-key",
		Description: "Roundtrip test",
		Role:        RoleOperator,
		KeyHash:     "secret-hash",
		KeyPrefix:   "dagu_rnd",
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   "creator",
	}

	// Convert to storage and back
	storage := original.ToStorage()
	recovered := storage.ToAPIKey()

	assert.Equal(t, original.ID, recovered.ID)
	assert.Equal(t, original.Name, recovered.Name)
	assert.Equal(t, original.Description, recovered.Description)
	assert.Equal(t, original.Role, recovered.Role)
	assert.Equal(t, original.KeyHash, recovered.KeyHash)
	assert.Equal(t, original.KeyPrefix, recovered.KeyPrefix)
	assert.Equal(t, original.CreatedAt, recovered.CreatedAt)
	assert.Equal(t, original.UpdatedAt, recovered.UpdatedAt)
	assert.Equal(t, original.CreatedBy, recovered.CreatedBy)
	assert.Equal(t, original.LastUsedAt, recovered.LastUsedAt)
}

func TestAPIKey_JSONSerialization(t *testing.T) {
	now := time.Now().UTC().Truncate(time.Second) // Truncate for JSON round-trip
	key := &APIKey{
		ID:          "key-id",
		Name:        "json-key",
		Description: "JSON test",
		Role:        RoleAdmin,
		KeyHash:     "should-be-excluded",
		KeyPrefix:   "dagu_jsn",
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   "creator",
	}

	// Serialize to JSON
	data, err := json.Marshal(key)
	require.NoError(t, err)

	// KeyHash should NOT be in the JSON (json:"-" tag)
	jsonStr := string(data)
	assert.NotContains(t, jsonStr, "should-be-excluded")
	assert.NotContains(t, jsonStr, "key_hash")

	// Deserialize back
	var recovered APIKey
	err = json.Unmarshal(data, &recovered)
	require.NoError(t, err)

	assert.Equal(t, key.ID, recovered.ID)
	assert.Equal(t, key.Name, recovered.Name)
	assert.Equal(t, key.Description, recovered.Description)
	assert.Equal(t, key.Role, recovered.Role)
	assert.Equal(t, key.KeyPrefix, recovered.KeyPrefix)
	assert.Empty(t, recovered.KeyHash, "KeyHash should not be deserialized")
}

func TestAPIKeyForStorage_JSONSerialization(t *testing.T) {
	now := time.Now().UTC().Truncate(time.Second) // Truncate for JSON round-trip
	storage := &APIKeyForStorage{
		ID:          "key-id",
		Name:        "storage-key",
		Description: "Storage test",
		Role:        RoleManager,
		KeyHash:     "included-hash",
		KeyPrefix:   "dagu_str",
		CreatedAt:   now,
		UpdatedAt:   now,
		CreatedBy:   "admin",
	}

	// Serialize to JSON
	data, err := json.Marshal(storage)
	require.NoError(t, err)

	// KeyHash SHOULD be in the JSON for storage
	jsonStr := string(data)
	assert.Contains(t, jsonStr, "included-hash")
	assert.Contains(t, jsonStr, "key_hash")

	// Deserialize back
	var recovered APIKeyForStorage
	err = json.Unmarshal(data, &recovered)
	require.NoError(t, err)

	assert.Equal(t, storage.ID, recovered.ID)
	assert.Equal(t, storage.Name, recovered.Name)
	assert.Equal(t, storage.Role, recovered.Role)
	assert.Equal(t, storage.KeyHash, recovered.KeyHash)
}

func TestAPIKey_NilLastUsedAt(t *testing.T) {
	key := &APIKey{
		ID:        "key-id",
		Name:      "nil-lastused",
		Role:      RoleViewer,
		KeyHash:   "hash",
		KeyPrefix: "dagu_nil",
		CreatedBy: "creator",
	}

	// LastUsedAt is nil by default
	assert.Nil(t, key.LastUsedAt)

	// Convert to storage and back
	storage := key.ToStorage()
	assert.Nil(t, storage.LastUsedAt)

	recovered := storage.ToAPIKey()
	assert.Nil(t, recovered.LastUsedAt)
}

func TestNewAPIKey_GeneratesUniqueIDs(t *testing.T) {
	ids := make(map[string]bool)
	for i := 0; i < 100; i++ {
		key, err := NewAPIKey("test", "", RoleViewer, "hash", "prefix", "creator")
		require.NoError(t, err)
		assert.False(t, ids[key.ID], "ID should be unique")
		ids[key.ID] = true
	}
}
internal/auth/context_test.go · new file, 87 lines
@@ -0,0 +1,87 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package auth

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestWithUser(t *testing.T) {
	t.Run("stores user in context", func(t *testing.T) {
		user := NewUser("testuser", "hash", RoleAdmin)
		ctx := WithUser(context.Background(), user)

		retrieved, ok := UserFromContext(ctx)
		require.True(t, ok)
		assert.Equal(t, user.ID, retrieved.ID)
		assert.Equal(t, user.Username, retrieved.Username)
		assert.Equal(t, user.Role, retrieved.Role)
	})

	t.Run("stores nil user in context", func(t *testing.T) {
		ctx := WithUser(context.Background(), nil)

		retrieved, ok := UserFromContext(ctx)
		assert.True(t, ok)
		assert.Nil(t, retrieved)
	})
}

func TestUserFromContext(t *testing.T) {
	t.Run("returns user when present", func(t *testing.T) {
		user := NewUser("admin", "hash", RoleManager)
		ctx := WithUser(context.Background(), user)

		retrieved, ok := UserFromContext(ctx)
		require.True(t, ok)
		assert.Equal(t, user, retrieved)
	})

	t.Run("returns false when user not present", func(t *testing.T) {
		ctx := context.Background()

		retrieved, ok := UserFromContext(ctx)
		assert.False(t, ok)
		assert.Nil(t, retrieved)
	})

	t.Run("returns false when wrong type in context", func(t *testing.T) {
		// Manually add wrong type to context with the same key pattern
		ctx := context.WithValue(context.Background(), contextKey("auth_user"), "not a user")

		retrieved, ok := UserFromContext(ctx)
		assert.False(t, ok)
		assert.Nil(t, retrieved)
	})

	t.Run("preserves user through context chain", func(t *testing.T) {
		user := NewUser("chainuser", "hash", RoleViewer)
		ctx := WithUser(context.Background(), user)

		// Add more values to the context chain using a typed key
		type otherKey struct{}
		ctx = context.WithValue(ctx, otherKey{}, "other_value")

		retrieved, ok := UserFromContext(ctx)
		require.True(t, ok)
		assert.Equal(t, user.Username, retrieved.Username)
	})

	t.Run("latest user overwrites previous", func(t *testing.T) {
		user1 := NewUser("user1", "hash1", RoleAdmin)
		user2 := NewUser("user2", "hash2", RoleViewer)

		ctx := WithUser(context.Background(), user1)
		ctx = WithUser(ctx, user2)

		retrieved, ok := UserFromContext(ctx)
		require.True(t, ok)
		assert.Equal(t, user2.Username, retrieved.Username)
		assert.Equal(t, user2.Role, retrieved.Role)
	})
}
@@ -134,3 +134,25 @@ func TestAllRoles(t *testing.T) {
		t.Error("AllRoles() returned a reference to internal state")
	}
}

func TestRole_String(t *testing.T) {
	tests := []struct {
		role Role
		want string
	}{
		{RoleAdmin, "admin"},
		{RoleManager, "manager"},
		{RoleOperator, "operator"},
		{RoleViewer, "viewer"},
		{Role("custom"), "custom"},
		{Role(""), ""},
	}

	for _, tt := range tests {
		t.Run(tt.want, func(t *testing.T) {
			if got := tt.role.String(); got != tt.want {
				t.Errorf("Role(%q).String() = %v, want %v", tt.role, got, tt.want)
			}
		})
	}
}
@@ -21,6 +21,23 @@ var (
	ErrInvalidUserID = errors.New("invalid user ID")
)

// Common errors for API key store operations.
var (
	// ErrAPIKeyNotFound is returned when an API key cannot be found.
	ErrAPIKeyNotFound = errors.New("API key not found")
	// ErrAPIKeyAlreadyExists is returned when attempting to create an API key
	// with a name that already exists.
	ErrAPIKeyAlreadyExists = errors.New("API key already exists")
	// ErrInvalidAPIKeyName is returned when the API key name is invalid.
	ErrInvalidAPIKeyName = errors.New("invalid API key name")
	// ErrInvalidAPIKeyID is returned when the API key ID is invalid.
	ErrInvalidAPIKeyID = errors.New("invalid API key ID")
	// ErrInvalidAPIKeyHash is returned when the API key hash is empty.
	ErrInvalidAPIKeyHash = errors.New("invalid API key hash")
	// ErrInvalidRole is returned when the role is not a valid role.
	ErrInvalidRole = errors.New("invalid role")
)

// UserStore defines the interface for user persistence operations.
// Implementations must be safe for concurrent use.
type UserStore interface {
@@ -50,3 +67,30 @@ type UserStore interface {
	// Count returns the total number of users.
	Count(ctx context.Context) (int64, error)
}

// APIKeyStore defines the interface for API key persistence operations.
// Implementations must be safe for concurrent use.
type APIKeyStore interface {
	// Create stores a new API key.
	// Returns ErrAPIKeyAlreadyExists if an API key with the same name exists.
	Create(ctx context.Context, key *APIKey) error

	// GetByID retrieves an API key by its unique ID.
	// Returns ErrAPIKeyNotFound if the API key does not exist.
	GetByID(ctx context.Context, id string) (*APIKey, error)

	// List returns all API keys in the store.
	List(ctx context.Context) ([]*APIKey, error)

	// Update modifies an existing API key.
	// Returns ErrAPIKeyNotFound if the API key does not exist.
	Update(ctx context.Context, key *APIKey) error

	// Delete removes an API key by its ID.
	// Returns ErrAPIKeyNotFound if the API key does not exist.
	Delete(ctx context.Context, id string) error

	// UpdateLastUsed updates the LastUsedAt timestamp for an API key.
	// This is called when the API key is used for authentication.
	UpdateLastUsed(ctx context.Context, id string) error
}
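To make the APIKeyStore contract concrete, here is a minimal in-memory sketch satisfying the interface above (illustrative only; the diff does not show Dagu's actual persistence layer):

// memoryAPIKeyStore is an illustrative, mutex-guarded APIKeyStore.
type memoryAPIKeyStore struct {
	mu   sync.RWMutex
	keys map[string]*APIKey // keyed by ID
}

func newMemoryAPIKeyStore() *memoryAPIKeyStore {
	return &memoryAPIKeyStore{keys: make(map[string]*APIKey)}
}

func (s *memoryAPIKeyStore) Create(_ context.Context, key *APIKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, existing := range s.keys {
		if existing.Name == key.Name {
			return ErrAPIKeyAlreadyExists // names must be unique
		}
	}
	s.keys[key.ID] = key
	return nil
}

func (s *memoryAPIKeyStore) GetByID(_ context.Context, id string) (*APIKey, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	key, ok := s.keys[id]
	if !ok {
		return nil, ErrAPIKeyNotFound
	}
	return key, nil
}

func (s *memoryAPIKeyStore) List(_ context.Context) ([]*APIKey, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	out := make([]*APIKey, 0, len(s.keys))
	for _, k := range s.keys {
		out = append(out, k)
	}
	return out, nil
}

func (s *memoryAPIKeyStore) Update(_ context.Context, key *APIKey) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.keys[key.ID]; !ok {
		return ErrAPIKeyNotFound
	}
	key.UpdatedAt = time.Now().UTC()
	s.keys[key.ID] = key
	return nil
}

func (s *memoryAPIKeyStore) Delete(_ context.Context, id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.keys[id]; !ok {
		return ErrAPIKeyNotFound
	}
	delete(s.keys, id)
	return nil
}

func (s *memoryAPIKeyStore) UpdateLastUsed(_ context.Context, id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	key, ok := s.keys[id]
	if !ok {
		return ErrAPIKeyNotFound
	}
	now := time.Now().UTC()
	key.LastUsedAt = &now
	return nil
}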
internal/auth/user_test.go · new file, 230 lines
@@ -0,0 +1,230 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later

package auth

import (
	"encoding/json"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewUser(t *testing.T) {
	t.Run("creates user with all fields", func(t *testing.T) {
		before := time.Now().UTC()
		user := NewUser("testuser", "hashedpassword", RoleManager)
		after := time.Now().UTC()

		assert.NotEmpty(t, user.ID)
		assert.Equal(t, "testuser", user.Username)
		assert.Equal(t, "hashedpassword", user.PasswordHash)
		assert.Equal(t, RoleManager, user.Role)

		assert.True(t, !user.CreatedAt.Before(before))
		assert.True(t, !user.CreatedAt.After(after))
		assert.Equal(t, user.CreatedAt, user.UpdatedAt)
	})

	t.Run("generates unique IDs", func(t *testing.T) {
		ids := make(map[string]bool)
		for i := 0; i < 100; i++ {
			user := NewUser("user", "hash", RoleViewer)
			assert.False(t, ids[user.ID], "ID should be unique")
			ids[user.ID] = true
		}
	})

	t.Run("supports all roles", func(t *testing.T) {
		roles := []Role{RoleAdmin, RoleManager, RoleOperator, RoleViewer}
		for _, role := range roles {
			user := NewUser("user", "hash", role)
			assert.Equal(t, role, user.Role)
		}
	})

	t.Run("allows empty username", func(t *testing.T) {
		user := NewUser("", "hash", RoleViewer)
		assert.Empty(t, user.Username)
		assert.NotEmpty(t, user.ID)
	})

	t.Run("allows empty password hash", func(t *testing.T) {
		user := NewUser("user", "", RoleViewer)
		assert.Empty(t, user.PasswordHash)
	})
}

func TestUser_ToStorage(t *testing.T) {
	t.Run("converts all fields", func(t *testing.T) {
		now := time.Now().UTC()
		user := &User{
			ID:           "user-123",
			Username:     "admin",
			PasswordHash: "secret-hash",
			Role:         RoleAdmin,
			CreatedAt:    now,
			UpdatedAt:    now.Add(time.Hour),
		}

		storage := user.ToStorage()

		assert.Equal(t, user.ID, storage.ID)
		assert.Equal(t, user.Username, storage.Username)
		assert.Equal(t, user.PasswordHash, storage.PasswordHash)
		assert.Equal(t, user.Role, storage.Role)
		assert.Equal(t, user.CreatedAt, storage.CreatedAt)
		assert.Equal(t, user.UpdatedAt, storage.UpdatedAt)
	})

	t.Run("handles empty fields", func(t *testing.T) {
		user := &User{}
		storage := user.ToStorage()

		assert.Empty(t, storage.ID)
		assert.Empty(t, storage.Username)
		assert.Empty(t, storage.PasswordHash)
	})
}

func TestUserForStorage_ToUser(t *testing.T) {
	t.Run("converts all fields", func(t *testing.T) {
		now := time.Now().UTC()
		storage := &UserForStorage{
			ID:           "storage-456",
			Username:     "operator",
			PasswordHash: "stored-hash",
			Role:         RoleOperator,
			CreatedAt:    now,
			UpdatedAt:    now.Add(2 * time.Hour),
		}

		user := storage.ToUser()

		assert.Equal(t, storage.ID, user.ID)
		assert.Equal(t, storage.Username, user.Username)
		assert.Equal(t, storage.PasswordHash, user.PasswordHash)
		assert.Equal(t, storage.Role, user.Role)
		assert.Equal(t, storage.CreatedAt, user.CreatedAt)
		assert.Equal(t, storage.UpdatedAt, user.UpdatedAt)
	})

	t.Run("handles empty fields", func(t *testing.T) {
		storage := &UserForStorage{}
		user := storage.ToUser()

		assert.Empty(t, user.ID)
		assert.Empty(t, user.Username)
		assert.Empty(t, user.PasswordHash)
	})
}

func TestUser_StorageRoundtrip(t *testing.T) {
	t.Run("preserves all fields through roundtrip", func(t *testing.T) {
		now := time.Now().UTC()
		original := &User{
			ID:           "roundtrip-id",
			Username:     "roundtrip-user",
			PasswordHash: "roundtrip-hash",
			Role:         RoleManager,
			CreatedAt:    now,
			UpdatedAt:    now.Add(time.Minute),
		}

		storage := original.ToStorage()
		recovered := storage.ToUser()

		assert.Equal(t, original.ID, recovered.ID)
		assert.Equal(t, original.Username, recovered.Username)
		assert.Equal(t, original.PasswordHash, recovered.PasswordHash)
		assert.Equal(t, original.Role, recovered.Role)
		assert.Equal(t, original.CreatedAt, recovered.CreatedAt)
		assert.Equal(t, original.UpdatedAt, recovered.UpdatedAt)
	})
}

func TestUser_JSONSerialization(t *testing.T) {
	t.Run("excludes password hash", func(t *testing.T) {
		user := &User{
			ID:           "json-id",
			Username:     "jsonuser",
			PasswordHash: "should-be-excluded",
			Role:         RoleViewer,
		}

		data, err := json.Marshal(user)
		require.NoError(t, err)

		jsonStr := string(data)
		assert.NotContains(t, jsonStr, "should-be-excluded")
		assert.NotContains(t, jsonStr, "password_hash")
		assert.NotContains(t, jsonStr, "PasswordHash")
	})

	t.Run("includes other fields", func(t *testing.T) {
		now := time.Now().UTC().Truncate(time.Second)
		user := &User{
			ID:        "json-id",
			Username:  "jsonuser",
			Role:      RoleAdmin,
			CreatedAt: now,
			UpdatedAt: now,
		}

		data, err := json.Marshal(user)
		require.NoError(t, err)

		var recovered User
		err = json.Unmarshal(data, &recovered)
		require.NoError(t, err)

		assert.Equal(t, user.ID, recovered.ID)
		assert.Equal(t, user.Username, recovered.Username)
		assert.Equal(t, user.Role, recovered.Role)
		assert.Empty(t, recovered.PasswordHash)
	})
}

func TestUserForStorage_JSONSerialization(t *testing.T) {
	t.Run("includes password hash", func(t *testing.T) {
		storage := &UserForStorage{
			ID:           "storage-id",
			Username:     "storageuser",
			PasswordHash: "included-hash",
			Role:         RoleManager,
		}

		data, err := json.Marshal(storage)
		require.NoError(t, err)

		jsonStr := string(data)
		assert.Contains(t, jsonStr, "included-hash")
		assert.Contains(t, jsonStr, "password_hash")
	})

	t.Run("roundtrip preserves all fields", func(t *testing.T) {
		now := time.Now().UTC().Truncate(time.Second)
		original := &UserForStorage{
			ID:           "storage-rt-id",
			Username:     "storage-rt-user",
			PasswordHash: "storage-rt-hash",
			Role:         RoleOperator,
			CreatedAt:    now,
			UpdatedAt:    now,
		}

		data, err := json.Marshal(original)
		require.NoError(t, err)

		var recovered UserForStorage
		err = json.Unmarshal(data, &recovered)
		require.NoError(t, err)

		assert.Equal(t, original.ID, recovered.ID)
		assert.Equal(t, original.Username, recovered.Username)
		assert.Equal(t, original.PasswordHash, recovered.PasswordHash)
		assert.Equal(t, original.Role, recovered.Role)
	})
}
internal/cmd/cleanup.go · new file, 134 lines
@ -0,0 +1,134 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/dagu-org/dagu/internal/core/execution"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Cleanup creates and returns a cobra command for removing old DAG run history.
|
||||
func Cleanup() *cobra.Command {
|
||||
return NewCommand(
|
||||
&cobra.Command{
|
||||
Use: "cleanup [flags] <DAG name>",
|
||||
Short: "Remove old DAG run history",
|
||||
Long: `Remove old DAG run history for a specified DAG.
|
||||
|
||||
By default, removes all history except for currently active runs.
|
||||
Use --retention-days to keep recent history.
|
||||
|
||||
Active runs are never deleted for safety.
|
||||
|
||||
Examples:
|
||||
dagu cleanup my-workflow # Delete all history (with confirmation)
|
||||
dagu cleanup --retention-days 30 my-workflow # Keep last 30 days
|
||||
dagu cleanup --dry-run my-workflow # Preview what would be deleted
|
||||
dagu cleanup -y my-workflow # Skip confirmation
|
||||
`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
},
|
||||
cleanupFlags,
|
||||
runCleanup,
|
||||
)
|
||||
}
|
||||
|
||||
var cleanupFlags = []commandLineFlag{
|
||||
retentionDaysFlag,
|
||||
dryRunFlag,
|
||||
yesFlag,
|
||||
}
|
||||
|
||||
func runCleanup(ctx *Context, args []string) error {
|
||||
dagName := args[0]
|
||||
|
||||
// Parse retention days (flags are string-based in this codebase)
|
||||
retentionStr, err := ctx.StringParam("retention-days")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get retention-days: %w", err)
|
||||
}
|
||||
retentionDays, err := strconv.Atoi(retentionStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid retention-days value %q: must be a non-negative integer", retentionStr)
|
||||
}
|
||||
|
||||
// Reject negative retention (clearer error than silent no-op)
|
||||
if retentionDays < 0 {
|
||||
return fmt.Errorf("retention-days cannot be negative (got %d)", retentionDays)
|
||||
}
|
||||
|
||||
// Get boolean flags
|
||||
dryRun, _ := ctx.Command.Flags().GetBool("dry-run")
|
||||
skipConfirm, _ := ctx.Command.Flags().GetBool("yes")
|
||||
|
||||
// Build description message
|
||||
var actionDesc string
|
||||
if retentionDays == 0 {
|
||||
actionDesc = fmt.Sprintf("all history for DAG %q", dagName)
|
||||
} else {
|
||||
actionDesc = fmt.Sprintf("history older than %d days for DAG %q", retentionDays, dagName)
|
||||
}
|
||||
|
||||
// Build options for RemoveOldDAGRuns
|
||||
var opts []execution.RemoveOldDAGRunsOption
|
||||
if dryRun {
|
||||
opts = append(opts, execution.WithDryRun())
|
||||
}
|
||||
|
||||
// Dry run mode - show what would be deleted
|
||||
if dryRun {
|
||||
runIDs, err := ctx.DAGRunStore.RemoveOldDAGRuns(ctx, dagName, retentionDays, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check history for %q: %w", dagName, err)
|
||||
}
|
||||
|
||||
if len(runIDs) == 0 {
|
||||
fmt.Printf("Dry run: No runs to delete for DAG %q\n", dagName)
|
||||
} else {
|
||||
fmt.Printf("Dry run: Would delete %d run(s) for DAG %q:\n", len(runIDs), dagName)
|
||||
for _, runID := range runIDs {
|
||||
fmt.Printf(" - %s\n", runID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirmation prompt (unless --yes or --quiet)
|
||||
if !skipConfirm && !ctx.Quiet {
|
||||
fmt.Printf("This will delete %s.\n", actionDesc)
|
||||
fmt.Println("Active runs will be preserved.")
|
||||
fmt.Print("Continue? [y/N]: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
response, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read user input: %w", err)
|
||||
}
|
||||
response = strings.TrimSpace(strings.ToLower(response))
|
||||
|
||||
if response != "y" && response != "yes" {
|
||||
fmt.Println("Cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute cleanup using the existing DAGRunStore method
|
||||
runIDs, err := ctx.DAGRunStore.RemoveOldDAGRuns(ctx, dagName, retentionDays, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cleanup history for %q: %w", dagName, err)
|
||||
}
|
||||
|
||||
if !ctx.Quiet {
|
||||
if len(runIDs) == 0 {
|
||||
fmt.Printf("No runs to delete for DAG %q\n", dagName)
|
||||
} else {
|
||||
fmt.Printf("Successfully removed %d run(s) for DAG %q\n", len(runIDs), dagName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
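
Note that the dry-run branch above goes through the same store call as the real deletion, differing only in the WithDryRun option. A minimal sketch of that call pattern, assuming a store value satisfying the same DAGRunStore interface used here:

	// Sketch only: preview which runs would be deleted. WithDryRun is the same
	// option the command builds above, so nothing is removed while it is set.
	opts := []execution.RemoveOldDAGRunsOption{execution.WithDryRun()}
	runIDs, err := store.RemoveOldDAGRuns(ctx, "my-workflow", 30, opts...)
	if err == nil {
		for _, id := range runIDs {
			fmt.Println("would delete:", id)
		}
	}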
280	internal/cmd/cleanup_test.go	Normal file
@@ -0,0 +1,280 @@
package cmd_test

import (
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/dagu-org/dagu/internal/cmd"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestCleanupCommand(t *testing.T) {
	t.Run("DeletesAllHistoryWithRetentionZero", func(t *testing.T) {
		th := test.SetupCommand(t)

		// Create a DAG and run it to generate history
		dag := th.DAG(t, `steps:
  - name: "1"
    command: echo "hello"
`)
		// Run the DAG to create history
		th.RunCommand(t, cmd.Start(), test.CmdTest{
			Args: []string{"start", dag.Location},
		})

		// Wait for the DAG to complete
		dag.AssertLatestStatus(t, core.Succeeded)

		// Verify history exists
		dag.AssertDAGRunCount(t, 1)

		// Run cleanup with --yes to skip confirmation
		th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--yes", dag.Name},
		})

		// Verify history is deleted
		dag.AssertDAGRunCount(t, 0)
	})

	t.Run("PreservesRecentHistoryWithRetentionDays", func(t *testing.T) {
		th := test.SetupCommand(t)

		// Create a DAG and run it
		dag := th.DAG(t, `steps:
  - name: "1"
    command: echo "hello"
`)
		// Run the DAG to create history
		th.RunCommand(t, cmd.Start(), test.CmdTest{
			Args: []string{"start", dag.Location},
		})

		// Wait for the DAG to complete
		dag.AssertLatestStatus(t, core.Succeeded)

		// Verify history exists
		dag.AssertDAGRunCount(t, 1)

		// Run cleanup with a retention of 30 days (should keep recent history)
		th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--retention-days", "30", "--yes", dag.Name},
		})

		// Verify history is still there (it is less than 30 days old)
		dag.AssertDAGRunCount(t, 1)
	})

	t.Run("DryRunDoesNotDelete", func(t *testing.T) {
		th := test.SetupCommand(t)

		// Create a DAG and run it
		dag := th.DAG(t, `steps:
  - name: "1"
    command: echo "hello"
`)
		// Run the DAG to create history
		th.RunCommand(t, cmd.Start(), test.CmdTest{
			Args: []string{"start", dag.Location},
		})

		// Wait for the DAG to complete
		dag.AssertLatestStatus(t, core.Succeeded)

		// Verify history exists
		dag.AssertDAGRunCount(t, 1)

		// Run cleanup with --dry-run
		th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--dry-run", dag.Name},
		})

		// Verify history is still there (dry run must not delete)
		dag.AssertDAGRunCount(t, 1)
	})

	t.Run("PreservesActiveRuns", func(t *testing.T) {
		th := test.SetupCommand(t)

		// Create a DAG that runs for a while
		dag := th.DAG(t, `steps:
  - name: "1"
    command: sleep 30
`)

		done := make(chan struct{})
		go func() {
			// Start the DAG
			th.RunCommand(t, cmd.Start(), test.CmdTest{
				Args: []string{"start", dag.Location},
			})
			close(done)
		}()

		// Wait for the DAG to start running
		time.Sleep(time.Millisecond * 200)
		dag.AssertLatestStatus(t, core.Running)

		// Try to cleanup while running (nothing to delete since only the active run exists)
		th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--yes", dag.Name},
		})

		// Verify the running DAG is still there (it should be preserved)
		dag.AssertLatestStatus(t, core.Running)

		// Stop the DAG
		th.RunCommand(t, cmd.Stop(), test.CmdTest{
			Args: []string{"stop", dag.Location},
		})

		<-done
	})

	t.Run("RejectsNegativeRetentionDays", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `steps:
  - name: "1"
    command: echo "hello"
`)

		err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--retention-days", "-1", dag.Name},
		})

		require.Error(t, err)
		assert.Contains(t, err.Error(), "cannot be negative")
	})

	t.Run("RejectsInvalidRetentionDays", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `steps:
  - name: "1"
    command: echo "hello"
`)

		err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--retention-days", "abc", dag.Name},
		})

		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid retention-days")
	})

	t.Run("RequiresDAGNameArgument", func(t *testing.T) {
		th := test.SetupCommand(t)

		err := th.RunCommandWithError(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--yes"},
		})

		require.Error(t, err)
		assert.Contains(t, err.Error(), "accepts 1 arg")
	})

	t.Run("SucceedsForNonExistentDAG", func(t *testing.T) {
		th := test.SetupCommand(t)

		// Cleanup for a DAG that doesn't exist should succeed silently
		th.RunCommand(t, cmd.Cleanup(), test.CmdTest{
			Args: []string{"cleanup", "--yes", "non-existent-dag"},
		})
	})
}

func TestCleanupCommandDirectStore(t *testing.T) {
	// Test cleanup using the DAGRunStore directly to verify the underlying behavior
	t.Run("RemoveOldDAGRunsWithStore", func(t *testing.T) {
		th := test.Setup(t)

		dagName := "test-cleanup-dag"

		// Create old DAG runs directly in the store
		oldTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
		recentTime := time.Now()

		// Create a minimal DAG for the test
		testDAG := &core.DAG{Name: dagName}

		// Create an old run
		oldAttempt, err := th.DAGRunStore.CreateAttempt(
			th.Context,
			testDAG,
			oldTime,
			"old-run-id",
			execution.NewDAGRunAttemptOptions{},
		)
		require.NoError(t, err)
		require.NoError(t, oldAttempt.Open(th.Context))
		require.NoError(t, oldAttempt.Write(th.Context, execution.DAGRunStatus{
			Name:     dagName,
			DAGRunID: "old-run-id",
			Status:   core.Succeeded,
		}))
		require.NoError(t, oldAttempt.Close(th.Context))

		// Create a recent run
		recentAttempt, err := th.DAGRunStore.CreateAttempt(
			th.Context,
			testDAG,
			recentTime,
			"recent-run-id",
			execution.NewDAGRunAttemptOptions{},
		)
		require.NoError(t, err)
		require.NoError(t, recentAttempt.Open(th.Context))
		require.NoError(t, recentAttempt.Write(th.Context, execution.DAGRunStatus{
			Name:     dagName,
			DAGRunID: "recent-run-id",
			Status:   core.Succeeded,
		}))
		require.NoError(t, recentAttempt.Close(th.Context))

		// Manually set an old file modification time
		setOldModTime(t, th.Config.Paths.DAGRunsDir, dagName, "", oldTime)

		// Verify both runs exist
		runs := th.DAGRunStore.RecentAttempts(th.Context, dagName, 10)
		require.Len(t, runs, 2)

		// Remove runs older than 7 days
		removedIDs, err := th.DAGRunStore.RemoveOldDAGRuns(th.Context, dagName, 7)
		require.NoError(t, err)
		assert.Len(t, removedIDs, 1)

		// Verify the old run is deleted and the recent run remains
		runs = th.DAGRunStore.RecentAttempts(th.Context, dagName, 10)
		require.Len(t, runs, 1)

		status, err := runs[0].ReadStatus(th.Context)
		require.NoError(t, err)
		assert.Equal(t, "recent-run-id", status.DAGRunID)
	})
}

// setOldModTime sets an old modification time on DAG run files
func setOldModTime(t *testing.T, baseDir, dagName, _ string, modTime time.Time) {
	t.Helper()

	// Find the run directory
	dagRunsDir := filepath.Join(baseDir, dagName, "dag-runs")
	err := filepath.Walk(dagRunsDir, func(path string, _ os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Set the mod time on all files and directories
		return os.Chtimes(path, modTime, modTime)
	})
	// Ignore errors if the directory doesn't exist
	if err != nil && !os.IsNotExist(err) {
		t.Logf("Warning: failed to set mod time: %v", err)
	}
}
@@ -2,6 +2,7 @@ package cmd

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
@@ -27,6 +28,7 @@ import (
	"github.com/dagu-org/dagu/internal/persistence/filequeue"
	"github.com/dagu-org/dagu/internal/persistence/fileserviceregistry"
	"github.com/dagu-org/dagu/internal/runtime"
	"github.com/dagu-org/dagu/internal/runtime/transform"
	"github.com/dagu-org/dagu/internal/service/coordinator"
	"github.com/dagu-org/dagu/internal/service/frontend"
	"github.com/dagu-org/dagu/internal/service/resource"
@@ -355,9 +357,7 @@ func (c *Context) GenLogFileName(dag *core.DAG, dagRunID string) (string, error)

// NewCommand creates a new command instance with the given cobra command and run function.
func NewCommand(cmd *cobra.Command, flags []commandLineFlag, runFunc func(cmd *Context, args []string) error) *cobra.Command {
	config.WithViperLock(func() {
		initFlags(cmd, flags...)
	})
	initFlags(cmd, flags...)

	cmd.SilenceUsage = true

@@ -490,6 +490,53 @@ func (cfg LogConfig) LogDir() (string, error) {
	return logDir, nil
}

// RecordEarlyFailure records a failure in the execution history before the DAG has fully started.
// This is used for infrastructure errors like singleton conflicts or process acquisition failures.
func (c *Context) RecordEarlyFailure(dag *core.DAG, dagRunID string, err error) error {
	if dag == nil || dagRunID == "" {
		return fmt.Errorf("DAG and dag-run ID are required to record failure")
	}

	// 1. Check if a DAGRunAttempt already exists for the given run-id.
	ref := execution.NewDAGRunRef(dag.Name, dagRunID)
	attempt, findErr := c.DAGRunStore.FindAttempt(c, ref)
	if findErr != nil && !errors.Is(findErr, execution.ErrDAGRunIDNotFound) {
		return fmt.Errorf("failed to check for existing attempt: %w", findErr)
	}

	if attempt == nil {
		// 2. Create the attempt if it does not exist
		att, createErr := c.DAGRunStore.CreateAttempt(c, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
		if createErr != nil {
			return fmt.Errorf("failed to create run to record failure: %w", createErr)
		}
		attempt = att
	}

	// 3. Construct the "Failed" status
	statusBuilder := transform.NewStatusBuilder(dag)
	logPath, _ := c.GenLogFileName(dag, dagRunID)
	status := statusBuilder.Create(dagRunID, core.Failed, 0, time.Now(),
		transform.WithLogFilePath(logPath),
		transform.WithFinishedAt(time.Now()),
		transform.WithError(err.Error()),
	)

	// 4. Write the status
	if err := attempt.Open(c); err != nil {
		return fmt.Errorf("failed to open attempt for recording failure: %w", err)
	}
	defer func() {
		_ = attempt.Close(c)
	}()

	if err := attempt.Write(c, status); err != nil {
		return fmt.Errorf("failed to write failed status: %w", err)
	}

	return nil
}
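
RecordEarlyFailure is what the start, retry, and restart paths later in this diff call when process acquisition fails before the agent starts, so the run shows up as Failed in history instead of vanishing. A minimal caller sketch under the same assumptions (a populated *Context named ctx, a loaded dag, and a generated dagRunID):

	// Sketch only: mirrors the call sites further down in this change set.
	if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
		// Best effort: the run is recorded as Failed even though it never started.
		_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
		return errProcAcquisitionFailed
	}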

// LogFile constructs the log filename using the prefix, safe DAG name, current timestamp,
// and a truncated version of the dag-run ID.
func (cfg LogConfig) LogFile() string {

@@ -167,6 +167,8 @@ func TestModel(t *testing.T) {
	})
}

var _ execution.DAGStore = (*mockDAGStore)(nil)

// mockDAGStore is a mock implementation of models.DAGStore
type mockDAGStore struct {
	mock.Mock

@@ -65,21 +65,6 @@ func runEnqueue(ctx *Context, args []string) error {
		dag.Queue = queueOverride
	}

	// Check queued DAG-runs
	queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
	if err != nil {
		return fmt.Errorf("failed to read queue: %w", err)
	}

	// If the DAG has a queue configured and maxActiveRuns > 1, ensure the number
	// of active runs in the queue does not exceed this limit.
	// There is no need to check when maxActiveRuns <= 1 for enqueueing, as the
	// queue-level maxConcurrency will be the only cap.
	if dag.Queue != "" && dag.MaxActiveRuns > 1 && len(queuedRuns) >= dag.MaxActiveRuns {
		// The same DAG is already in the queue
		return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot enqueue", dag.Name, dag.MaxActiveRuns)
	}

	return enqueueDAGRun(ctx, dag, runID)
}

@@ -1,9 +1,7 @@
package cmd

import (
	"errors"
	"fmt"
	"log/slog"
	"os"
	"strings"

@@ -21,7 +19,6 @@ const (
	flagWorkdir      = "workdir"
	flagShell        = "shell"
	flagBase         = "base"
	flagSingleton    = "singleton"
	defaultStepName  = "main"
	execCommandUsage = "exec [flags] -- <command> [args...]"
)
@@ -30,12 +27,9 @@ var (
	execFlags = []commandLineFlag{
		dagRunIDFlag,
		nameFlag,
		queueFlag,
		noQueueFlag,
		workdirFlag,
		shellFlag,
		baseFlag,
		singletonFlag,
	}
)

@@ -49,7 +43,7 @@ func Exec() *cobra.Command {
Examples:
  dagu exec -- echo "hello world"
  dagu exec --env FOO=bar -- sh -c 'echo $FOO'
  dagu exec --queue nightly --worker-label role=batch -- python nightly.py`,
  dagu exec --worker-label role=batch -- python remote_script.py`,
		Args: cobra.ArbitraryArgs,
	}

@@ -62,14 +56,12 @@ Examples:
	return command
}

// runExec parses flags and arguments and executes the provided command as an inline DAG run,
// either enqueueing it for distributed execution or running it immediately in-process.
// It validates inputs (run-id, working directory, base and dotenv files, env vars, worker labels,
// queue/singleton flags), builds the DAG for the inline command, and chooses between enqueueing
// (when queues/worker labels require it or when max runs are reached) or direct execution.
// runExec parses flags and arguments and executes the provided command as an inline DAG run.
// It validates inputs (run-id, working directory, base and dotenv files, env vars, worker labels),
// builds the DAG for the inline command, and executes it locally.
// ctx provides CLI and application context; args are the command and its arguments.
// Returns an error for validation failures, when a dag-run with the same run-id already exists,
// or if enqueueing/execution fails.
// or if execution fails.
func runExec(ctx *Context, args []string) error {
	if len(args) == 0 {
		return fmt.Errorf("command is required (try: dagu exec -- <command>)")
@@ -177,29 +169,10 @@ func runExec(ctx *Context, args []string) error {
		workerLabels[key] = value
	}

	queueName, err := ctx.Command.Flags().GetString("queue")
	if err != nil {
		return fmt.Errorf("failed to read queue flag: %w", err)
	}
	noQueue, err := ctx.Command.Flags().GetBool("no-queue")
	if err != nil {
		return fmt.Errorf("failed to read no-queue flag: %w", err)
	}

	singleton, err := ctx.Command.Flags().GetBool(flagSingleton)
	if err != nil {
		return fmt.Errorf("failed to read singleton flag: %w", err)
	}

	queueDisabled := !ctx.Config.Queues.Enabled || noQueue
	if len(workerLabels) > 0 {
		if !ctx.Config.Queues.Enabled {
			return fmt.Errorf("worker selector requires queues; enable queues or remove --worker-label")
		}
		if noQueue {
			return fmt.Errorf("--worker-label cannot be combined with --no-queue")
		}
		queueDisabled = false
	}

	opts := ExecOptions{
@@ -210,9 +183,6 @@ func runExec(ctx *Context, args []string) error {
		Env:          envVars,
		DotenvFiles:  dotenvPaths,
		BaseConfig:   baseConfig,
		Queue:        queueName,
		NoQueue:      noQueue,
		Singleton:    singleton,
		WorkerLabels: workerLabels,
	}

@@ -223,28 +193,6 @@ func runExec(ctx *Context, args []string) error {

	dagRunRef := execution.NewDAGRunRef(dag.Name, runID)

	if !queueDisabled && len(workerLabels) > 0 {
		logger.Info(ctx, "Queueing inline dag-run for distributed execution",
			tag.DAG(dag.Name),
			tag.RunID(runID),
			slog.Any("worker-selector", workerLabels),
			tag.Command(strings.Join(args, " ")),
		)
		dag.Location = ""
		return enqueueDAGRun(ctx, dag, runID)
	}

	if !queueDisabled && dag.Queue != "" {
		logger.Info(ctx, "Queueing inline dag-run",
			tag.DAG(dag.Name),
			tag.Queue(dag.Queue),
			tag.RunID(runID),
			tag.Command(strings.Join(args, " ")),
		)
		dag.Location = ""
		return enqueueDAGRun(ctx, dag, runID)
	}

	attempt, _ := ctx.DAGRunStore.FindAttempt(ctx, dagRunRef)
	if attempt != nil {
		return fmt.Errorf("dag-run ID %s already exists for DAG %s", runID, dag.Name)
@@ -256,16 +204,7 @@ func runExec(ctx *Context, args []string) error {
		tag.RunID(runID),
	)

	err = tryExecuteDAG(ctx, dag, runID, dagRunRef, false)
	if errors.Is(err, errMaxRunReached) && !queueDisabled {
		logger.Info(ctx, "Max active runs reached; enqueueing dag-run instead",
			tag.DAG(dag.Name),
			tag.RunID(runID),
		)
		dag.Location = ""
		return enqueueDAGRun(ctx, dag, runID)
	}
	return err
	return tryExecuteDAG(ctx, dag, runID, dagRunRef)
}

var (
@@ -281,9 +220,4 @@ var (
		name:  flagBase,
		usage: "Path to a base DAG YAML whose defaults are applied before inline overrides",
	}
	singletonFlag = commandLineFlag{
		name:   flagSingleton,
		usage:  "Limit execution to a single active run (sets maxActiveRuns=1)",
		isBool: true,
	}
)

@@ -6,6 +6,7 @@ import (
	"path/filepath"
	"strings"

	"github.com/dagu-org/dagu/internal/common/cmdutil"
	"github.com/dagu-org/dagu/internal/common/fileutil"
	"github.com/dagu-org/dagu/internal/common/stringutil"
	"github.com/dagu-org/dagu/internal/core"
@@ -22,9 +23,6 @@ type ExecOptions struct {
	Env          []string
	DotenvFiles  []string
	BaseConfig   string
	Queue        string
	NoQueue      bool
	Singleton    bool
	WorkerLabels map[string]string
}

@@ -34,16 +32,14 @@ type execSpec struct {
	WorkingDir     string            `yaml:"workingDir,omitempty"`
	Env            []string          `yaml:"env,omitempty"`
	Dotenv         []string          `yaml:"dotenv,omitempty"`
	MaxActiveRuns  int               `yaml:"maxActiveRuns,omitempty"`
	Queue          string            `yaml:"queue,omitempty"`
	WorkerSelector map[string]string `yaml:"workerSelector,omitempty"`
	Steps          []execStep        `yaml:"steps"`
}

type execStep struct {
	Name    string   `yaml:"name"`
	Command []string `yaml:"command,omitempty"`
	Shell   string   `yaml:"shell,omitempty"`
	Name    string `yaml:"name"`
	Command string `yaml:"command,omitempty"`
	Shell   string `yaml:"shell,omitempty"`
}

func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
@@ -59,15 +55,8 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
		return nil, "", fmt.Errorf("invalid DAG name: %w", err)
	}

	maxActiveRuns := -1
	if opts.Singleton {
		maxActiveRuns = 1
	}

	queueValue := ""
	if opts.Queue != "" && !opts.NoQueue {
		queueValue = opts.Queue
	}
	// Build the command string from args
	commandStr := cmdutil.BuildCommandEscapedString(opts.CommandArgs[0], opts.CommandArgs[1:])

	specDoc := execSpec{
		Name: name,
@@ -75,13 +64,11 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {
		WorkingDir:     opts.WorkingDir,
		Env:            opts.Env,
		Dotenv:         opts.DotenvFiles,
		MaxActiveRuns:  maxActiveRuns,
		Queue:          queueValue,
		WorkerSelector: opts.WorkerLabels,
		Steps: []execStep{
			{
				Name:    defaultStepName,
				Command: opts.CommandArgs,
				Command: commandStr,
				Shell:   opts.ShellOverride,
			},
		},
@@ -127,19 +114,10 @@ func buildExecDAG(ctx *Context, opts ExecOptions) (*core.DAG, string, error) {

	dag.Name = name
	dag.WorkingDir = opts.WorkingDir
	if opts.Queue != "" && !opts.NoQueue {
		dag.Queue = opts.Queue
	} else if opts.NoQueue {
		dag.Queue = ""
	}
	if len(opts.WorkerLabels) > 0 {
		dag.WorkerSelector = opts.WorkerLabels
	}
	if opts.Singleton {
		dag.MaxActiveRuns = 1
	} else {
		dag.MaxActiveRuns = -1
	}
	dag.MaxActiveRuns = -1
	dag.Location = ""

	return dag, string(specYAML), nil
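
The execStep change above collapses the argv-style Command []string into a single escaped string built by cmdutil.BuildCommandEscapedString. A hypothetical before/after of the resulting step (the exact quoting produced by that helper is an assumption here, not its documented output):

	// Before: Command: []string{"sh", "-c", "echo $FOO"}
	// After (sketch; exact escaping is an assumption):
	step := execStep{
		Name:    defaultStepName,
		Command: `sh -c 'echo $FOO'`,
	}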

@@ -1,7 +1,6 @@
package cmd

import (
	"github.com/dagu-org/dagu/internal/common/config"
	"github.com/dagu-org/dagu/internal/common/stringutil"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
@@ -68,14 +67,6 @@ var (
		usage: "Override the DAG name (default: name from DAG definition or filename)",
	}

	// noQueueFlag indicates that the dag-run should not be queued and should be executed immediately.
	noQueueFlag = commandLineFlag{
		name:      "no-queue",
		usage:     "Do not queue the dag-run, execute immediately",
		isBool:    true,
		shorthand: "n",
	}

	// Unique dag-run ID required for retrying a dag-run.
	// This flag must be provided when using the retry command.
	dagRunIDFlagRetry = commandLineFlag{
@@ -93,15 +84,6 @@ var (
		defaultValue: "",
	}

	// noCheckMaxActiveRuns
	disableMaxActiveRuns = commandLineFlag{
		name:         "disable-max-active-runs",
		shorthand:    "",
		usage:        "Disable check for max active runs",
		isBool:       true,
		defaultValue: "",
	}

	// Unique dag-run ID used for starting a new dag-run.
	// This is used to track and identify the execution instance and its status.
	dagRunIDFlag = commandLineFlag{
@@ -267,6 +249,30 @@ var (
		isBool:    true,
		bindViper: true,
	}

	// retentionDaysFlag specifies the number of days to retain history.
	// Records older than this will be deleted.
	// If set to 0, all records (except active runs) will be deleted.
	retentionDaysFlag = commandLineFlag{
		name:         "retention-days",
		defaultValue: "0",
		usage:        "Number of days to retain history (0 = delete all, except active runs)",
	}

	// dryRunFlag enables preview mode without actual deletion.
	dryRunFlag = commandLineFlag{
		name:   "dry-run",
		usage:  "Preview what would be deleted without actually deleting",
		isBool: true,
	}

	// yesFlag skips the confirmation prompt.
	yesFlag = commandLineFlag{
		name:      "yes",
		shorthand: "y",
		usage:     "Skip confirmation prompt",
		isBool:    true,
	}
)

type commandLineFlag struct {
@@ -296,16 +302,13 @@ func initFlags(cmd *cobra.Command, additionalFlags ...commandLineFlag) {

// bindFlags binds command-line flags to the provided Viper instance for configuration lookup.
// It binds only flags whose `bindViper` field is true, using the camel-cased key produced
// from each flag's kebab-case name. Binding is performed while holding the config package's
// Viper lock to ensure thread-safe registration.
// from each flag's kebab-case name.
func bindFlags(viper *viper.Viper, cmd *cobra.Command, additionalFlags ...commandLineFlag) {
	flags := append([]commandLineFlag{configFlag}, additionalFlags...)

	config.WithViperLock(func() {
		for _, flag := range flags {
			if flag.bindViper {
				_ = viper.BindPFlag(stringutil.KebabToCamel(flag.name), cmd.Flags().Lookup(flag.name))
			}
		}
	})
	for _, flag := range flags {
		if flag.bindViper {
			_ = viper.BindPFlag(stringutil.KebabToCamel(flag.name), cmd.Flags().Lookup(flag.name))
		}
	}
}
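
For flags with bindViper set, the kebab-case CLI name maps to a camel-case Viper key via stringutil.KebabToCamel; judging by the helper's name, a flag called "retention-days" would bind to the key "retentionDays". A sketch of the resulting binding under that assumption:

	// Sketch only: the expected key shape, assuming KebabToCamel("retention-days")
	// yields "retentionDays".
	_ = viper.BindPFlag("retentionDays", cmd.Flags().Lookup("retention-days"))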

@@ -17,6 +17,8 @@ import (
	"github.com/stretchr/testify/require"
)

var _ execution.DAGStore = (*mockDAGStore)(nil)

// mockDAGStore implements models.DAGStore for testing
type mockDAGStore struct {
	dags map[string]*core.DAG

182	internal/cmd/record_early_failure_test.go	Normal file
@@ -0,0 +1,182 @@
package cmd_test

import (
	"errors"
	"testing"
	"time"

	"github.com/dagu-org/dagu/internal/cmd"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/require"
)

func TestRecordEarlyFailure(t *testing.T) {
	t.Run("RecordsFailureForNewDAGRun", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `
steps:
  - name: step1
    command: echo hello
`)

		dagRunID := "test-run-id-001"
		testErr := errors.New("process acquisition failed")

		// Create a Context with the required stores
		ctx := &cmd.Context{
			Context:     th.Context,
			Config:      th.Config,
			DAGRunStore: th.DAGRunStore,
		}

		// Record the early failure
		err := ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
		require.NoError(t, err)

		// Verify the failure was recorded
		ref := execution.NewDAGRunRef(dag.Name, dagRunID)
		attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
		require.NoError(t, err)
		require.NotNil(t, attempt)

		status, err := attempt.ReadStatus(th.Context)
		require.NoError(t, err)
		require.Equal(t, core.Failed, status.Status)
		require.Contains(t, status.Error, "process acquisition failed")
	})

	t.Run("RecordsFailureForExistingAttempt", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `
steps:
  - name: step1
    command: echo hello
`)

		// First, run the DAG to create an attempt
		th.RunCommand(t, cmd.Start(), test.CmdTest{
			Args: []string{"start", dag.Location},
		})

		// Get the existing run ID
		latestStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
		require.NoError(t, err)
		dagRunID := latestStatus.DAGRunID

		// Now record an early failure for the same run ID
		testErr := errors.New("retry failed due to lock contention")

		ctx := &cmd.Context{
			Context:     th.Context,
			Config:      th.Config,
			DAGRunStore: th.DAGRunStore,
		}

		err = ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
		require.NoError(t, err)

		// Verify the failure was recorded (the status should be updated)
		ref := execution.NewDAGRunRef(dag.Name, dagRunID)
		attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
		require.NoError(t, err)
		require.NotNil(t, attempt)

		status, err := attempt.ReadStatus(th.Context)
		require.NoError(t, err)
		require.Equal(t, core.Failed, status.Status)
		require.Contains(t, status.Error, "retry failed due to lock contention")
	})

	t.Run("ReturnsErrorForNilDAG", func(t *testing.T) {
		th := test.SetupCommand(t)

		ctx := &cmd.Context{
			Context:     th.Context,
			Config:      th.Config,
			DAGRunStore: th.DAGRunStore,
		}

		err := ctx.RecordEarlyFailure(nil, "some-run-id", errors.New("test error"))
		require.Error(t, err)
		require.Contains(t, err.Error(), "DAG and dag-run ID are required")
	})

	t.Run("ReturnsErrorForEmptyDAGRunID", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `
steps:
  - name: step1
    command: echo hello
`)

		ctx := &cmd.Context{
			Context:     th.Context,
			Config:      th.Config,
			DAGRunStore: th.DAGRunStore,
		}

		err := ctx.RecordEarlyFailure(dag.DAG, "", errors.New("test error"))
		require.Error(t, err)
		require.Contains(t, err.Error(), "DAG and dag-run ID are required")
	})

	t.Run("CanRetryEarlyFailureRecord", func(t *testing.T) {
		th := test.SetupCommand(t)

		dag := th.DAG(t, `
steps:
  - name: step1
    command: echo hello
`)

		dagRunID := "early-failure-retry-test"
		testErr := errors.New("initial process acquisition failed")

		// Create a Context and record the early failure
		ctx := &cmd.Context{
			Context:     th.Context,
			Config:      th.Config,
			DAGRunStore: th.DAGRunStore,
		}

		err := ctx.RecordEarlyFailure(dag.DAG, dagRunID, testErr)
		require.NoError(t, err)

		// Verify the initial failure status
		ref := execution.NewDAGRunRef(dag.Name, dagRunID)
		attempt, err := th.DAGRunStore.FindAttempt(th.Context, ref)
		require.NoError(t, err)
		require.NotNil(t, attempt)

		status, err := attempt.ReadStatus(th.Context)
		require.NoError(t, err)
		require.Equal(t, core.Failed, status.Status)

		// Verify the DAG can be read back (required for retry)
		storedDAG, err := attempt.ReadDAG(th.Context)
		require.NoError(t, err)
		require.NotNil(t, storedDAG)
		require.Equal(t, dag.Name, storedDAG.Name)

		// Now retry the early-failure record
		th.RunCommand(t, cmd.Retry(), test.CmdTest{
			Args: []string{"retry", "--run-id", dagRunID, dag.Name},
		})

		// Wait for the retry to complete
		require.Eventually(t, func() bool {
			currentStatus, err := th.DAGRunMgr.GetCurrentStatus(th.Context, dag.DAG, dagRunID)
			return err == nil && currentStatus != nil && currentStatus.Status == core.Succeeded
		}, 5*time.Second, 100*time.Millisecond, "Retry should succeed")

		// Verify the final status is succeeded
		finalStatus, err := th.DAGRunMgr.GetCurrentStatus(th.Context, dag.DAG, dagRunID)
		require.NoError(t, err)
		require.Equal(t, core.Succeeded, finalStatus.Status)
	})
}

@@ -87,9 +87,9 @@ func runRestart(ctx *Context, args []string) error {
	return nil
}

func handleRestartProcess(ctx *Context, d *core.DAG, dagRunID string) error {
func handleRestartProcess(ctx *Context, d *core.DAG, oldDagRunID string) error {
	// Stop if running
	if err := stopDAGIfRunning(ctx, ctx.DAGRunMgr, d, dagRunID); err != nil {
	if err := stopDAGIfRunning(ctx, ctx.DAGRunMgr, d, oldDagRunID); err != nil {
		return err
	}

@@ -99,36 +99,40 @@ func handleRestartProcess(ctx *Context, d *core.DAG, dagRunID string) error {
		time.Sleep(d.RestartWait)
	}

	// Generate a new dag-run ID for the restart
	newDagRunID, err := genRunID()
	if err != nil {
		return fmt.Errorf("failed to generate dag-run ID: %w", err)
	}

	// Execute the exact same DAG with the same parameters but a new dag-run ID
	if err := ctx.ProcStore.Lock(ctx, d.ProcGroup()); err != nil {
		logger.Debug(ctx, "Failed to lock process group", tag.Error(err))
		return errMaxRunReached
		_ = ctx.RecordEarlyFailure(d, newDagRunID, err)
		return errProcAcquisitionFailed
	}
	defer ctx.ProcStore.Unlock(ctx, d.ProcGroup())

	// Acquire a process handle
	proc, err := ctx.ProcStore.Acquire(ctx, d.ProcGroup(), execution.NewDAGRunRef(d.Name, dagRunID))
	proc, err := ctx.ProcStore.Acquire(ctx, d.ProcGroup(), execution.NewDAGRunRef(d.Name, newDagRunID))
	if err != nil {
		ctx.ProcStore.Unlock(ctx, d.ProcGroup())
		logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
		return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
		_ = ctx.RecordEarlyFailure(d, newDagRunID, err)
		return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
	}
	defer func() {
		_ = proc.Stop(ctx)
	}()

	// Unlock the process group
	// Unlock the process group immediately after acquiring the handle
	ctx.ProcStore.Unlock(ctx, d.ProcGroup())

	return executeDAG(ctx, ctx.DAGRunMgr, d)
	return executeDAGWithRunID(ctx, ctx.DAGRunMgr, d, newDagRunID)
}

// It returns an error if run ID generation, log or DAG store initialization, or agent execution fails.
func executeDAG(ctx *Context, cli runtime.Manager, dag *core.DAG) error {
	dagRunID, err := genRunID()
	if err != nil {
		return fmt.Errorf("failed to generate dag-run ID: %w", err)
	}

// executeDAGWithRunID executes a DAG with a pre-generated run ID.
// It returns an error if log or DAG store initialization, or agent execution fails.
func executeDAGWithRunID(ctx *Context, cli runtime.Manager, dag *core.DAG, dagRunID string) error {
	logFile, err := ctx.OpenLogFile(dag, dagRunID)
	if err != nil {
		return fmt.Errorf("failed to initialize log file: %w", err)

@@ -2,18 +2,14 @@ package cmd

import (
	"fmt"
	"log/slog"
	"path/filepath"
	"time"

	"github.com/dagu-org/dagu/internal/common/fileutil"
	"github.com/dagu-org/dagu/internal/common/logger"
	"github.com/dagu-org/dagu/internal/common/logger/tag"
	"github.com/dagu-org/dagu/internal/common/stringutil"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/runtime/agent"
	"github.com/dagu-org/dagu/internal/runtime/transform"
	"github.com/spf13/cobra"
)

@@ -37,12 +33,12 @@ Examples:
	)
}

var retryFlags = []commandLineFlag{dagRunIDFlagRetry, stepNameForRetry, disableMaxActiveRuns, noQueueFlag}
var retryFlags = []commandLineFlag{dagRunIDFlagRetry, stepNameForRetry}

func runRetry(ctx *Context, args []string) error {
	// Extract retry details
	dagRunID, _ := ctx.StringParam("run-id")
	stepName, _ := ctx.StringParam("step")
	disableMaxActiveRuns := ctx.Command.Flags().Changed("disable-max-active-runs")

	name, err := extractDAGName(ctx, args[0])
	if err != nil {
@@ -71,56 +67,18 @@ func runRetry(ctx *Context, args []string) error {
	// Set the DAG context for all logs
	ctx.Context = logger.WithValues(ctx.Context, tag.DAG(dag.Name), tag.RunID(dagRunID))

	// Check if the queue is disabled via config or flag
	queueDisabled := !ctx.Config.Queues.Enabled || ctx.Command.Flags().Changed("no-queue")

	// Check if this DAG should be distributed to workers.
	// If the DAG has a workerSelector and the queue is not disabled,
	// enqueue it so the scheduler can dispatch it to a worker.
	// The --no-queue flag acts as a circuit breaker to prevent infinite loops
	// when the worker executes the dispatched retry task.
	if !queueDisabled && len(dag.WorkerSelector) > 0 {
		logger.Info(ctx, "DAG has workerSelector, enqueueing retry for distributed execution", slog.Any("worker-selector", dag.WorkerSelector))

		// Enqueue the retry - it must create a new attempt with status "Queued"
		// so the scheduler will process it
		if err := enqueueRetry(ctx, dag, dagRunID); err != nil {
			return fmt.Errorf("failed to enqueue retry: %w", err)
		}

		logger.Info(ctx.Context, "Retry enqueued")
		return nil
	}

	// Try to lock the proc store to avoid a race
	if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
		return fmt.Errorf("failed to lock process group: %w", err)
	}
	defer ctx.ProcStore.Unlock(ctx, dag.ProcGroup())

	if !disableMaxActiveRuns {
		liveCount, err := ctx.ProcStore.CountAliveByDAGName(ctx, dag.ProcGroup(), dag.Name)
		if err != nil {
			return fmt.Errorf("failed to access proc store: %w", err)
		}
		// Count queued DAG-runs and return an error if the total of a new run plus
		// active runs would exceed maxActiveRuns.
		queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
		if err != nil {
			return fmt.Errorf("failed to read queue: %w", err)
		}
		// If the DAG has a queue configured and maxActiveRuns > 0, ensure the number
		// of active runs in the queue does not exceed this limit.
		if dag.MaxActiveRuns > 0 && len(queuedRuns)+liveCount >= dag.MaxActiveRuns {
			return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot start", dag.Name, dag.MaxActiveRuns)
		}
	}

	// Acquire a process handle
	proc, err := ctx.ProcStore.Acquire(ctx, dag.ProcGroup(), execution.NewDAGRunRef(dag.Name, dagRunID))
	if err != nil {
		ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
		logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
		return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
		_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
		return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
	}
	defer func() {
		_ = proc.Stop(ctx)
@@ -189,65 +147,3 @@ func executeRetry(ctx *Context, dag *core.DAG, status *execution.DAGRunStatus, r
	// Use the shared agent execution function
	return ExecuteAgent(ctx, agentInstance, dag, status.DAGRunID, logFile)
}

// enqueueRetry creates a new attempt for retry and enqueues it for execution
func enqueueRetry(ctx *Context, dag *core.DAG, dagRunID string) error {
	// Queued dag-runs must not have a location because it is used to generate
	// the unix pipe. If two DAGs have the same location, they cannot run at the same time.
	// Queued DAGs can be run at the same time depending on the `maxActiveRuns` setting.
	dag.Location = ""

	// Check if queues are enabled
	if !ctx.Config.Queues.Enabled {
		return fmt.Errorf("queues are disabled in configuration")
	}

	// Create a new attempt for the retry
	att, err := ctx.DAGRunStore.CreateAttempt(ctx.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{
		Retry: true,
	})
	if err != nil {
		return fmt.Errorf("failed to create retry attempt: %w", err)
	}

	// Generate the log file name
	logFile, err := ctx.GenLogFileName(dag, dagRunID)
	if err != nil {
		return fmt.Errorf("failed to generate log file name: %w", err)
	}

	// Create status for the new attempt with "Queued" status
	opts := []transform.StatusOption{
		transform.WithLogFilePath(logFile),
		transform.WithAttemptID(att.ID()),
		transform.WithPreconditions(dag.Preconditions),
		transform.WithQueuedAt(stringutil.FormatTime(time.Now())),
		transform.WithHierarchyRefs(
			execution.NewDAGRunRef(dag.Name, dagRunID),
			execution.DAGRunRef{},
		),
	}

	dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{}, opts...)

	// Write the status
	if err := att.Open(ctx.Context); err != nil {
		return fmt.Errorf("failed to open attempt: %w", err)
	}
	defer func() {
		_ = att.Close(ctx.Context)
	}()
	if err := att.Write(ctx.Context, dagStatus); err != nil {
		return fmt.Errorf("failed to save status: %w", err)
	}

	// Enqueue the DAG run
	dagRun := execution.NewDAGRunRef(dag.Name, dagRunID)
	if err := ctx.QueueStore.Enqueue(ctx.Context, dag.ProcGroup(), execution.QueuePriorityLow, dagRun); err != nil {
		return fmt.Errorf("failed to enqueue: %w", err)
	}

	logger.Info(ctx, "Retry attempt created and enqueued", tag.AttemptID(att.ID()))

	return nil
}

@@ -57,7 +57,7 @@ This command parses the DAG definition, resolves parameters, and initiates the D
}

// Command line flags for the start command
var startFlags = []commandLineFlag{paramsFlag, nameFlag, dagRunIDFlag, fromRunIDFlag, parentDAGRunFlag, rootDAGRunFlag, noQueueFlag, disableMaxActiveRuns, defaultWorkingDirFlag}
var startFlags = []commandLineFlag{paramsFlag, nameFlag, dagRunIDFlag, fromRunIDFlag, parentDAGRunFlag, rootDAGRunFlag, defaultWorkingDirFlag}

var fromRunIDFlag = commandLineFlag{
	name: "from-run-id",
@@ -81,8 +81,6 @@ func runStart(ctx *Context, args []string) error {
		return fmt.Errorf("--from-run-id cannot be combined with --parent or --root")
	}

	disableMaxActiveRuns := ctx.Command.Flags().Changed("disable-max-active-runs")

	var (
		dag    *core.DAG
		params string
@@ -157,14 +155,6 @@ func runStart(ctx *Context, args []string) error {
		return handleSubDAGRun(ctx, dag, dagRunID, params, root, parent)
	}

	// Check if the queue is disabled via config or flag
	queueDisabled := !ctx.Config.Queues.Enabled

	// Check the no-queue flag (overrides config)
	if ctx.Command.Flags().Changed("no-queue") {
		queueDisabled = true
	}

	// Check that the DAG run-id is unique
	attempt, _ := ctx.DAGRunStore.FindAttempt(ctx, root)
	if attempt != nil {
@@ -172,15 +162,6 @@ func runStart(ctx *Context, args []string) error {
		return fmt.Errorf("dag-run ID %s already exists for DAG %s", dagRunID, dag.Name)
	}

	// Count running DAGs to check against the maxActiveRuns setting (best effort).
	liveCount, err := ctx.ProcStore.CountAliveByDAGName(ctx, dag.ProcGroup(), dag.Name)
	if err != nil {
		return fmt.Errorf("failed to access proc store: %w", err)
	}
	if !disableMaxActiveRuns && dag.MaxActiveRuns == 1 && liveCount > 0 {
		return fmt.Errorf("DAG %s is already running, cannot start", dag.Name)
	}

	// Log the root dag-run or reschedule action
	if fromRunID != "" {
		logger.Info(ctx, "Rescheduling dag-run",
@@ -191,78 +172,36 @@ func runStart(ctx *Context, args []string) error {
		logger.Info(ctx, "Executing root dag-run", slog.String("params", params))
	}

	// Check if this DAG should be distributed to workers.
	// If the DAG has a workerSelector and the queue is not disabled,
	// enqueue it so the scheduler can dispatch it to a worker.
	// The --no-queue flag acts as a circuit breaker to prevent infinite loops
	// when the worker executes the dispatched task.
	if !queueDisabled && len(dag.WorkerSelector) > 0 {
		logger.Info(ctx, "DAG has workerSelector, enqueueing for distributed execution", slog.Any("worker-selector", dag.WorkerSelector))
		dag.Location = "" // Queued dag-runs must not have a location
		return enqueueDAGRun(ctx, dag, dagRunID)
	}

	err = tryExecuteDAG(ctx, dag, dagRunID, root, disableMaxActiveRuns)
	if errors.Is(err, errMaxRunReached) && !queueDisabled && !disableMaxActiveRuns {
		dag.Location = "" // Queued dag-runs must not have a location

		// If the DAG has a queue configured and maxActiveRuns > 1, ensure the number
		// of active runs in the queue does not exceed this limit.
		// The scheduler only enforces maxActiveRuns at the global queue level.
		queuedRuns, err := ctx.QueueStore.ListByDAGName(ctx, dag.ProcGroup(), dag.Name)
		if err != nil {
			return fmt.Errorf("failed to read queue: %w", err)
		}
		if dag.Queue != "" && dag.MaxActiveRuns > 1 && len(queuedRuns)+liveCount >= dag.MaxActiveRuns {
			return fmt.Errorf("DAG %s is already in the queue (maxActiveRuns=%d), cannot start", dag.Name, dag.MaxActiveRuns)
		}

		// Enqueue the DAG-run for execution
		return enqueueDAGRun(ctx, dag, dagRunID)
	}

	return err // return executed result
	return tryExecuteDAG(ctx, dag, dagRunID, root)
}

var (
	errMaxRunReached         = errors.New("max run reached")
	errProcAcquisitionFailed = errors.New("failed to acquire process handle")
)

// tryExecuteDAG tries to run the DAG within the max concurrent run config
func tryExecuteDAG(ctx *Context, dag *core.DAG, dagRunID string, root execution.DAGRunRef, disableMaxActiveRuns bool) error {
// tryExecuteDAG tries to run the DAG.
func tryExecuteDAG(ctx *Context, dag *core.DAG, dagRunID string, root execution.DAGRunRef) error {
	if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
		logger.Debug(ctx, "Failed to lock process group", tag.Error(err))
		return errMaxRunReached
	}
	defer ctx.ProcStore.Unlock(ctx, dag.ProcGroup())

	if !disableMaxActiveRuns {
		runningCount, err := ctx.ProcStore.CountAlive(ctx, dag.ProcGroup())
		if err != nil {
			logger.Debug(ctx, "Failed to count live processes", tag.Error(err))
			return fmt.Errorf("failed to count live process for %s: %w", dag.ProcGroup(), errMaxRunReached)
		}

		// If the DAG has a queue configured and maxActiveRuns > 0, ensure the number
		// of active runs in the queue does not exceed this limit.
		if dag.MaxActiveRuns > 0 && runningCount >= dag.MaxActiveRuns {
			// It's not possible to run right now.
			return fmt.Errorf("max active run is reached (%d >= %d): %w", runningCount, dag.MaxActiveRuns, errMaxRunReached)
		}
		_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
		return errProcAcquisitionFailed
	}

	// Acquire a process handle
	proc, err := ctx.ProcStore.Acquire(ctx, dag.ProcGroup(), execution.NewDAGRunRef(dag.Name, dagRunID))
	if err != nil {
		ctx.ProcStore.Unlock(ctx, dag.ProcGroup())
		logger.Debug(ctx, "Failed to acquire process handle", tag.Error(err))
		return fmt.Errorf("failed to acquire process handle: %w", errMaxRunReached)
		_ = ctx.RecordEarlyFailure(dag, dagRunID, err)
		return fmt.Errorf("failed to acquire process handle: %w", errProcAcquisitionFailed)
	}
	defer func() {
		_ = proc.Stop(ctx)
	}()
	ctx.Proc = proc

	// Unlock the process group
	// Unlock the process group immediately after acquiring the handle
	// to allow other instances of the same DAG to start.
	ctx.ProcStore.Unlock(ctx, dag.ProcGroup())

	return executeDAGRun(ctx, dag, execution.DAGRunRef{}, dagRunID, root)
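
The sequencing in tryExecuteDAG above (lock the process group, acquire a per-run handle, then unlock immediately) keeps the critical section to the acquisition itself. A condensed sketch of the same pattern, using the names from the function above:

	// Sketch only: the group lock guards only the Acquire call; the proc handle,
	// not the lock, is what represents this dag-run while it executes.
	if err := ctx.ProcStore.Lock(ctx, dag.ProcGroup()); err != nil {
		return errProcAcquisitionFailed
	}
	proc, err := ctx.ProcStore.Acquire(ctx, dag.ProcGroup(), execution.NewDAGRunRef(dag.Name, dagRunID))
	ctx.ProcStore.Unlock(ctx, dag.ProcGroup()) // release early so other runs can start
	if err != nil {
		return err
	}
	defer func() { _ = proc.Stop(ctx) }()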
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/dagu-org/dagu/internal/common/logger"
|
||||
"github.com/dagu-org/dagu/internal/common/logger/tag"
|
||||
"github.com/dagu-org/dagu/internal/service/coordinator"
|
||||
"github.com/dagu-org/dagu/internal/service/resource"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@ -103,16 +102,9 @@ func runStartAll(ctx *Context, _ []string) error {
|
||||
}
|
||||
|
||||
// Only start coordinator if not bound to localhost
|
||||
var coordinator *coordinator.Service
|
||||
enableCoordinator := ctx.Config.Coordinator.Host != "127.0.0.1" && ctx.Config.Coordinator.Host != "localhost" && ctx.Config.Coordinator.Host != "::1"
|
||||
|
||||
if enableCoordinator {
|
||||
coordinator, err = newCoordinator(ctx, ctx.Config, ctx.ServiceRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize coordinator: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info(ctx, "Coordinator disabled (bound to localhost), set --coordinator.host and --coordinator.advertise to enable distributed mode")
|
||||
coordinator, err := newCoordinator(ctx, ctx.Config, ctx.ServiceRegistry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize coordinator: %w", err)
|
||||
}
|
||||
|
||||
// Create a new context with the signal context for services
|
||||
@ -136,10 +128,7 @@ func runStartAll(ctx *Context, _ []string) error {
|
||||
|
||||
// WaitGroup to track all services
|
||||
var wg sync.WaitGroup
|
||||
serviceCount := 2 // scheduler + server
|
||||
if enableCoordinator {
|
||||
serviceCount = 3 // + coordinator
|
||||
}
|
||||
serviceCount := 3 // scheduler + server + coordinator
|
||||
errCh := make(chan error, serviceCount)
|
||||
|
||||
// Start scheduler
|
||||
@ -157,19 +146,16 @@ func runStartAll(ctx *Context, _ []string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// Start coordinator (if enabled)
|
||||
if enableCoordinator {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := coordinator.Start(serviceCtx); err != nil {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("coordinator failed: %w", err):
|
||||
default:
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := coordinator.Start(serviceCtx); err != nil {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("coordinator failed: %w", err):
|
||||
default:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start server
|
||||
wg.Add(1)
|
||||
@ -203,13 +189,11 @@ func runStartAll(ctx *Context, _ []string) error {
|
||||
// Stop all services gracefully
|
||||
logger.Info(ctx, "Stopping all services")
|
||||
|
||||
// Stop coordinator first to unregister from service registry (if it was started)
|
||||
if enableCoordinator && coordinator != nil {
|
||||
if err := coordinator.Stop(ctx); err != nil {
|
||||
logger.Error(ctx, "Failed to stop coordinator",
|
||||
tag.Error(err),
|
||||
)
|
||||
}
|
||||
// Stop coordinator first to unregister from service registry
|
||||
if err := coordinator.Stop(ctx); err != nil {
|
||||
logger.Error(ctx, "Failed to stop coordinator",
|
||||
tag.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// Stop resource service
|
||||
|
||||
@ -224,6 +224,8 @@ func TestRetrier_Next(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
var _ RetryPolicy = (*mockRetryPolicy)(nil)
|
||||
|
||||
// mockRetryPolicy is a test helper that returns predefined intervals
|
||||
type mockRetryPolicy struct {
|
||||
intervals []time.Duration
|
||||
|
||||
@ -266,6 +266,14 @@ func processStructFields(ctx context.Context, v reflect.Value, opts *EvalOptions
|
||||
|
||||
// nolint:exhaustive
|
||||
switch field.Kind() {
|
||||
case reflect.Ptr:
|
||||
if field.IsNil() {
|
||||
continue
|
||||
}
|
||||
if err := processPointerField(ctx, field, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case reflect.String:
|
||||
value := field.String()
|
||||
|
||||
@@ -302,6 +310,137 @@ func processStructFields(ctx context.Context, v reflect.Value, opts *EvalOptions
|
||||
}
|
||||
|
||||
field.Set(processed)
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
if field.IsNil() {
|
||||
continue
|
||||
}
|
||||
// Create a new slice with new backing array to avoid mutating original
|
||||
newSlice := reflect.MakeSlice(field.Type(), field.Len(), field.Cap())
|
||||
reflect.Copy(newSlice, field)
|
||||
if err := processSliceWithOpts(ctx, newSlice, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
field.Set(newSlice)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func processPointerField(ctx context.Context, field reflect.Value, opts *EvalOptions) error {
|
||||
elem := field.Elem()
|
||||
if !elem.CanSet() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:exhaustive
|
||||
switch elem.Kind() {
|
||||
case reflect.String:
|
||||
value := elem.String()
|
||||
value = expandVariables(ctx, value, opts)
|
||||
if opts.Substitute {
|
||||
var err error
|
||||
value, err = substituteCommands(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.ExpandEnv {
|
||||
value = os.ExpandEnv(value)
|
||||
}
|
||||
// Create a new string and update the pointer to point to it
|
||||
// This avoids mutating the original value
|
||||
newStr := reflect.New(elem.Type())
|
||||
newStr.Elem().SetString(value)
|
||||
field.Set(newStr)
|
||||
|
||||
case reflect.Struct:
|
||||
// Create a copy of the struct to avoid mutating the original
|
||||
newStruct := reflect.New(elem.Type())
|
||||
newStruct.Elem().Set(elem)
|
||||
if err := processStructFields(ctx, newStruct.Elem(), opts); err != nil {
|
||||
return err
|
||||
}
|
||||
field.Set(newStruct)
|
||||
|
||||
case reflect.Map:
|
||||
if elem.IsNil() {
|
||||
return nil
|
||||
}
|
||||
processed, err := processMapWithOpts(ctx, elem, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Create a new pointer to the processed map
|
||||
newMap := reflect.New(elem.Type())
|
||||
newMap.Elem().Set(processed)
|
||||
field.Set(newMap)
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
if elem.IsNil() {
|
||||
return nil
|
||||
}
|
||||
// Create a new slice with new backing array to avoid mutating original
|
||||
newSlice := reflect.MakeSlice(elem.Type(), elem.Len(), elem.Cap())
|
||||
reflect.Copy(newSlice, elem)
|
||||
if err := processSliceWithOpts(ctx, newSlice, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
// Create new pointer to the new slice
|
||||
newPtr := reflect.New(elem.Type())
|
||||
newPtr.Elem().Set(newSlice)
|
||||
field.Set(newPtr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processSliceWithOpts(ctx context.Context, v reflect.Value, opts *EvalOptions) error {
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
elem := v.Index(i)
|
||||
if !elem.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
// nolint:exhaustive
|
||||
switch elem.Kind() {
|
||||
case reflect.String:
|
||||
value := elem.String()
|
||||
value = expandVariables(ctx, value, opts)
|
||||
if opts.Substitute {
|
||||
var err error
|
||||
value, err = substituteCommands(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opts.ExpandEnv {
|
||||
value = os.ExpandEnv(value)
|
||||
}
|
||||
elem.SetString(value)
|
||||
|
||||
case reflect.Struct:
|
||||
if err := processStructFields(ctx, elem, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
if elem.IsNil() {
|
||||
continue
|
||||
}
|
||||
processed, err := processMapWithOpts(ctx, elem, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
elem.Set(processed)
|
||||
|
||||
case reflect.Ptr:
|
||||
if elem.IsNil() {
|
||||
continue
|
||||
}
|
||||
if err := processPointerField(ctx, elem, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
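The pointer handling above is deliberately copy-on-write: instead of writing through an existing pointer, which would mutate the caller's original value, it allocates a fresh pointee and repoints the field. A minimal sketch of that idea, using an uppercase transform as a stand-in for the real variable expansion:

```go
package main

import (
	"fmt"
	"reflect"
	"strings"
)

// expandPtr mirrors the copy-on-write idea in processPointerField: rather
// than writing through the existing *string, it allocates a new string and
// repoints the struct field at it, leaving the original value untouched.
func expandPtr(field reflect.Value) {
	elem := field.Elem()
	expanded := strings.ToUpper(elem.String()) // stand-in for expansion
	newStr := reflect.New(elem.Type())
	newStr.Elem().SetString(expanded)
	field.Set(newStr)
}

func main() {
	orig := "hello"
	s := struct{ Token *string }{Token: &orig}
	expandPtr(reflect.ValueOf(&s).Elem().Field(0))
	fmt.Println(*s.Token, orig) // HELLO hello — the original is untouched
}
```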
@@ -3,7 +3,6 @@ package cmdutil
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
@@ -85,13 +84,12 @@ func TestEvalStringFields(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
got, err := EvalStringFields(ctx, tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SubstituteStringFields() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -107,12 +105,10 @@ func TestEvalStringFields_AnonymousStruct(t *testing.T) {
|
||||
require.Equal(t, "hello", obj.Field)
|
||||
}
|
||||
|
||||
func TestSubstituteStringFields_NonStruct(t *testing.T) {
|
||||
func TestEvalStringFields_NonStruct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, err := EvalStringFields(ctx, "not a struct")
|
||||
if err == nil {
|
||||
t.Error("SubstituteStringFields() should return error for non-struct input")
|
||||
}
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEvalStringFields_NestedStructs(t *testing.T) {
|
||||
@@ -160,13 +156,8 @@ func TestEvalStringFields_NestedStructs(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
got, err := EvalStringFields(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatalf("SubstituteStringFields() error = %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, want)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestEvalStringFields_EmptyStruct(t *testing.T) {
|
||||
@@ -175,13 +166,541 @@ func TestEvalStringFields_EmptyStruct(t *testing.T) {
|
||||
input := Empty{}
|
||||
ctx := context.Background()
|
||||
got, err := EvalStringFields(ctx, input)
|
||||
if err != nil {
|
||||
t.Fatalf("SubstituteStringFields() error = %v", err)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, input, got)
|
||||
}
|
||||
|
||||
func TestEvalStringFields_PointerFields(t *testing.T) {
|
||||
_ = os.Setenv("PTR_VAR", "pointer_value")
|
||||
defer func() {
|
||||
_ = os.Unsetenv("PTR_VAR")
|
||||
}()
|
||||
|
||||
type PointerNested struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, input) {
|
||||
t.Errorf("SubstituteStringFields() = %+v, want %+v", got, input)
|
||||
type PointerStruct struct {
|
||||
Token *string
|
||||
Nested *PointerNested
|
||||
Items []*PointerNested
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
token := "$PTR_VAR"
|
||||
input := PointerStruct{
|
||||
Token: &token,
|
||||
Nested: &PointerNested{Value: "${PTR_VAR}"},
|
||||
Items: []*PointerNested{{
|
||||
Value: "$PTR_VAR",
|
||||
}},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.Token)
|
||||
assert.Equal(t, "pointer_value", *result.Token)
|
||||
require.NotNil(t, result.Nested)
|
||||
assert.Equal(t, "pointer_value", result.Nested.Value)
|
||||
require.Len(t, result.Items, 1)
|
||||
assert.Equal(t, "pointer_value", result.Items[0].Value)
|
||||
}
|
||||
|
||||
func TestEvalStringFields_NilPointerFields(t *testing.T) {
|
||||
t.Setenv("NIL_TEST_VAR", "value")
|
||||
|
||||
type Nested struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
type StructWithNilPointers struct {
|
||||
NilString *string
|
||||
NilStruct *Nested
|
||||
NilMap *map[string]string
|
||||
NilSlice *[]string
|
||||
// Non-nil field to verify processing continues
|
||||
Regular string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
input := StructWithNilPointers{
|
||||
NilString: nil,
|
||||
NilStruct: nil,
|
||||
NilMap: nil,
|
||||
NilSlice: nil,
|
||||
Regular: "$NIL_TEST_VAR",
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result.NilString)
|
||||
assert.Nil(t, result.NilStruct)
|
||||
assert.Nil(t, result.NilMap)
|
||||
assert.Nil(t, result.NilSlice)
|
||||
assert.Equal(t, "value", result.Regular)
|
||||
}
|
||||
|
||||
func TestEvalStringFields_PointerToMap(t *testing.T) {
|
||||
t.Setenv("MAP_PTR_VAR", "map_ptr_value")
|
||||
|
||||
type StructWithPtrMap struct {
|
||||
Config *map[string]string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("PointerToMapWithEnvVars", func(t *testing.T) {
|
||||
mapVal := map[string]string{
|
||||
"key1": "$MAP_PTR_VAR",
|
||||
"key2": "${MAP_PTR_VAR}",
|
||||
}
|
||||
input := StructWithPtrMap{Config: &mapVal}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.Config)
|
||||
assert.Equal(t, "map_ptr_value", (*result.Config)["key1"])
|
||||
assert.Equal(t, "map_ptr_value", (*result.Config)["key2"])
|
||||
})
|
||||
|
||||
t.Run("PointerToNilMap", func(t *testing.T) {
|
||||
input := StructWithPtrMap{Config: nil}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result.Config)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_PointerToSlice(t *testing.T) {
|
||||
t.Setenv("SLICE_PTR_VAR", "slice_ptr_value")
|
||||
|
||||
type Nested struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
type StructWithPtrSlice struct {
|
||||
Items *[]string
|
||||
Structs *[]*Nested
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("PointerToSliceOfStrings", func(t *testing.T) {
|
||||
items := []string{"$SLICE_PTR_VAR", "${SLICE_PTR_VAR}"}
|
||||
input := StructWithPtrSlice{Items: &items}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.Items)
|
||||
require.Len(t, *result.Items, 2)
|
||||
assert.Equal(t, "slice_ptr_value", (*result.Items)[0])
|
||||
assert.Equal(t, "slice_ptr_value", (*result.Items)[1])
|
||||
})
|
||||
|
||||
t.Run("PointerToSliceOfStructPointers", func(t *testing.T) {
|
||||
structs := []*Nested{
|
||||
{Value: "$SLICE_PTR_VAR"},
|
||||
{Value: "${SLICE_PTR_VAR}"},
|
||||
}
|
||||
input := StructWithPtrSlice{Structs: &structs}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result.Structs)
|
||||
require.Len(t, *result.Structs, 2)
|
||||
assert.Equal(t, "slice_ptr_value", (*result.Structs)[0].Value)
|
||||
assert.Equal(t, "slice_ptr_value", (*result.Structs)[1].Value)
|
||||
})
|
||||
|
||||
t.Run("PointerToNilSlice", func(t *testing.T) {
|
||||
input := StructWithPtrSlice{Items: nil}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result.Items)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_SliceOfStrings(t *testing.T) {
|
||||
t.Setenv("SLICE_STR_VAR", "slice_str_value")
|
||||
|
||||
type StructWithStringSlice struct {
|
||||
Tags []string
|
||||
Labels []string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SliceWithEnvVars", func(t *testing.T) {
|
||||
input := StructWithStringSlice{
|
||||
Tags: []string{"$SLICE_STR_VAR", "${SLICE_STR_VAR}", "plain"},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Tags, 3)
|
||||
assert.Equal(t, "slice_str_value", result.Tags[0])
|
||||
assert.Equal(t, "slice_str_value", result.Tags[1])
|
||||
assert.Equal(t, "plain", result.Tags[2])
|
||||
})
|
||||
|
||||
t.Run("EmptySlice", func(t *testing.T) {
|
||||
input := StructWithStringSlice{
|
||||
Tags: []string{},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("NilSlice", func(t *testing.T) {
|
||||
input := StructWithStringSlice{
|
||||
Tags: nil,
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("SliceWithCommandSubstitution", func(t *testing.T) {
|
||||
input := StructWithStringSlice{
|
||||
Tags: []string{"`echo hello`", "`echo world`"},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Tags, 2)
|
||||
assert.Equal(t, "hello", result.Tags[0])
|
||||
assert.Equal(t, "world", result.Tags[1])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_SliceOfStructs(t *testing.T) {
|
||||
t.Setenv("STRUCT_SLICE_VAR", "struct_slice_value")
|
||||
|
||||
type Item struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
type StructWithStructSlice struct {
|
||||
Items []Item
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SliceOfStructsWithEnvVars", func(t *testing.T) {
|
||||
input := StructWithStructSlice{
|
||||
Items: []Item{
|
||||
{Name: "$STRUCT_SLICE_VAR", Value: "plain"},
|
||||
{Name: "plain", Value: "${STRUCT_SLICE_VAR}"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Items, 2)
|
||||
assert.Equal(t, "struct_slice_value", result.Items[0].Name)
|
||||
assert.Equal(t, "plain", result.Items[0].Value)
|
||||
assert.Equal(t, "plain", result.Items[1].Name)
|
||||
assert.Equal(t, "struct_slice_value", result.Items[1].Value)
|
||||
})
|
||||
|
||||
t.Run("EmptySliceOfStructs", func(t *testing.T) {
|
||||
input := StructWithStructSlice{
|
||||
Items: []Item{},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result.Items)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_SliceOfMaps(t *testing.T) {
|
||||
t.Setenv("MAP_SLICE_VAR", "map_slice_value")
|
||||
|
||||
type StructWithMapSlice struct {
|
||||
Configs []map[string]string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SliceOfMapsWithEnvVars", func(t *testing.T) {
|
||||
input := StructWithMapSlice{
|
||||
Configs: []map[string]string{
|
||||
{"key": "$MAP_SLICE_VAR"},
|
||||
{"key": "${MAP_SLICE_VAR}"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Configs, 2)
|
||||
assert.Equal(t, "map_slice_value", result.Configs[0]["key"])
|
||||
assert.Equal(t, "map_slice_value", result.Configs[1]["key"])
|
||||
})
|
||||
|
||||
t.Run("SliceWithNilMapElement", func(t *testing.T) {
|
||||
input := StructWithMapSlice{
|
||||
Configs: []map[string]string{
|
||||
{"key": "$MAP_SLICE_VAR"},
|
||||
nil,
|
||||
{"key": "plain"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Configs, 3)
|
||||
assert.Equal(t, "map_slice_value", result.Configs[0]["key"])
|
||||
assert.Nil(t, result.Configs[1])
|
||||
assert.Equal(t, "plain", result.Configs[2]["key"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_SliceWithNilPointers(t *testing.T) {
|
||||
t.Setenv("NIL_PTR_SLICE_VAR", "nil_ptr_slice_value")
|
||||
|
||||
type Nested struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
type StructWithPointerSlice struct {
|
||||
StringPtrs []*string
|
||||
StructPtrs []*Nested
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SliceWithMixedNilAndNonNilStringPointers", func(t *testing.T) {
|
||||
val1 := "$NIL_PTR_SLICE_VAR"
|
||||
val2 := "${NIL_PTR_SLICE_VAR}"
|
||||
input := StructWithPointerSlice{
|
||||
StringPtrs: []*string{&val1, nil, &val2},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.StringPtrs, 3)
|
||||
require.NotNil(t, result.StringPtrs[0])
|
||||
assert.Equal(t, "nil_ptr_slice_value", *result.StringPtrs[0])
|
||||
assert.Nil(t, result.StringPtrs[1])
|
||||
require.NotNil(t, result.StringPtrs[2])
|
||||
assert.Equal(t, "nil_ptr_slice_value", *result.StringPtrs[2])
|
||||
})
|
||||
|
||||
t.Run("SliceWithMixedNilAndNonNilStructPointers", func(t *testing.T) {
|
||||
input := StructWithPointerSlice{
|
||||
StructPtrs: []*Nested{
|
||||
{Value: "$NIL_PTR_SLICE_VAR"},
|
||||
nil,
|
||||
{Value: "${NIL_PTR_SLICE_VAR}"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := EvalStringFields(ctx, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.StructPtrs, 3)
|
||||
require.NotNil(t, result.StructPtrs[0])
|
||||
assert.Equal(t, "nil_ptr_slice_value", result.StructPtrs[0].Value)
|
||||
assert.Nil(t, result.StructPtrs[1])
|
||||
require.NotNil(t, result.StructPtrs[2])
|
||||
assert.Equal(t, "nil_ptr_slice_value", result.StructPtrs[2].Value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_PointerMutation(t *testing.T) {
|
||||
t.Setenv("MUT_VAR", "expanded")
|
||||
|
||||
t.Run("PointerToString", func(t *testing.T) {
|
||||
type S struct {
|
||||
Token *string
|
||||
}
|
||||
|
||||
original := "$MUT_VAR"
|
||||
input := S{Token: &original}
|
||||
|
||||
result, err := EvalStringFields(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check result is correct
|
||||
assert.Equal(t, "expanded", *result.Token)
|
||||
|
||||
// Check original was NOT mutated
|
||||
assert.Equal(t, "$MUT_VAR", original, "BUG: original variable was mutated")
|
||||
assert.Equal(t, "$MUT_VAR", *input.Token, "BUG: input struct's pointer target was mutated")
|
||||
})
|
||||
|
||||
t.Run("SliceOfStringPointers", func(t *testing.T) {
|
||||
type S struct {
|
||||
Items []*string
|
||||
}
|
||||
|
||||
val1 := "$MUT_VAR"
|
||||
val2 := "${MUT_VAR}"
|
||||
input := S{Items: []*string{&val1, &val2}}
|
||||
|
||||
result, err := EvalStringFields(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check result is correct
|
||||
require.Len(t, result.Items, 2)
|
||||
assert.Equal(t, "expanded", *result.Items[0])
|
||||
assert.Equal(t, "expanded", *result.Items[1])
|
||||
|
||||
// Check originals were NOT mutated
|
||||
assert.Equal(t, "$MUT_VAR", val1, "BUG: val1 was mutated")
|
||||
assert.Equal(t, "${MUT_VAR}", val2, "BUG: val2 was mutated")
|
||||
assert.Equal(t, "$MUT_VAR", *input.Items[0], "BUG: input.Items[0] target was mutated")
|
||||
assert.Equal(t, "${MUT_VAR}", *input.Items[1], "BUG: input.Items[1] target was mutated")
|
||||
})
|
||||
|
||||
t.Run("PointerToStruct", func(t *testing.T) {
|
||||
type Nested struct {
|
||||
Value string
|
||||
}
|
||||
type S struct {
|
||||
Nested *Nested
|
||||
}
|
||||
|
||||
input := S{Nested: &Nested{Value: "$MUT_VAR"}}
|
||||
originalValue := input.Nested.Value
|
||||
|
||||
result, err := EvalStringFields(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check result is correct
|
||||
assert.Equal(t, "expanded", result.Nested.Value)
|
||||
|
||||
// Check original was NOT mutated
|
||||
assert.Equal(t, "$MUT_VAR", originalValue, "BUG: original nested value was mutated")
|
||||
assert.Equal(t, "$MUT_VAR", input.Nested.Value, "BUG: input.Nested.Value was mutated")
|
||||
})
|
||||
|
||||
t.Run("PointerToSlice", func(t *testing.T) {
|
||||
type S struct {
|
||||
Items *[]string
|
||||
}
|
||||
|
||||
items := []string{"$MUT_VAR", "${MUT_VAR}"}
|
||||
input := S{Items: &items}
|
||||
|
||||
result, err := EvalStringFields(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check result is correct
|
||||
require.NotNil(t, result.Items)
|
||||
assert.Equal(t, "expanded", (*result.Items)[0])
|
||||
assert.Equal(t, "expanded", (*result.Items)[1])
|
||||
|
||||
// Check original was NOT mutated
|
||||
assert.Equal(t, "$MUT_VAR", items[0], "BUG: items[0] was mutated")
|
||||
assert.Equal(t, "${MUT_VAR}", items[1], "BUG: items[1] was mutated")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_PointerFieldErrors(t *testing.T) {
|
||||
type Nested struct {
|
||||
Command string
|
||||
}
|
||||
|
||||
type StructWithPointerErrors struct {
|
||||
StringPtr *string
|
||||
StructPtr *Nested
|
||||
MapPtr *map[string]string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("PointerToStringWithInvalidCommand", func(t *testing.T) {
|
||||
invalidCmd := "`invalid_command_xyz123`"
|
||||
input := StructWithPointerErrors{StringPtr: &invalidCmd}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("PointerToStructWithInvalidNestedCommand", func(t *testing.T) {
|
||||
input := StructWithPointerErrors{
|
||||
StructPtr: &Nested{Command: "`invalid_command_xyz123`"},
|
||||
}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("PointerToMapWithInvalidCommand", func(t *testing.T) {
|
||||
mapVal := map[string]string{
|
||||
"key": "`invalid_command_xyz123`",
|
||||
}
|
||||
input := StructWithPointerErrors{MapPtr: &mapVal}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvalStringFields_SliceFieldErrors(t *testing.T) {
|
||||
type Nested struct {
|
||||
Command string
|
||||
}
|
||||
|
||||
type StructWithSliceErrors struct {
|
||||
Strings []string
|
||||
Structs []Nested
|
||||
StringPtrs []*string
|
||||
Maps []map[string]string
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SliceStringWithInvalidCommand", func(t *testing.T) {
|
||||
input := StructWithSliceErrors{
|
||||
Strings: []string{"valid", "`invalid_command_xyz123`"},
|
||||
}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("SliceStructWithInvalidNestedCommand", func(t *testing.T) {
|
||||
input := StructWithSliceErrors{
|
||||
Structs: []Nested{
|
||||
{Command: "valid"},
|
||||
{Command: "`invalid_command_xyz123`"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("SlicePointerWithInvalidCommand", func(t *testing.T) {
|
||||
valid := "valid"
|
||||
invalid := "`invalid_command_xyz123`"
|
||||
input := StructWithSliceErrors{
|
||||
StringPtrs: []*string{&valid, &invalid},
|
||||
}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("SliceMapWithInvalidCommand", func(t *testing.T) {
|
||||
input := StructWithSliceErrors{
|
||||
Maps: []map[string]string{
|
||||
{"key": "valid"},
|
||||
{"key": "`invalid_command_xyz123`"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := EvalStringFields(ctx, input)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestReplaceVars(t *testing.T) {
|
||||
|
||||
internal/common/cmdutil/shell.go (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
package cmdutil

import (
	"strings"
)

// ShellQuote escapes a string for use in a shell command.
func ShellQuote(s string) string {
	if s == "" {
		return "''"
	}

	// Use a conservative set of safe characters:
	// Alphanumeric, hyphen, underscore, dot, and slash.
	// We only consider ASCII alphanumeric as safe to avoid locale-dependent behavior.
	safe := true
	for i := 0; i < len(s); i++ {
		b := s[i]
		if (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') ||
			b == '-' || b == '_' || b == '.' || b == '/' {
			continue
		}
		safe = false
		break
	}
	if safe {
		return s
	}

	// Wrap in single quotes and escape any internal single quotes.
	// This is the most robust way to escape for POSIX-compliant shells.
	// 'user's file' -> 'user'\''s file'
	return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}

// ShellQuoteArgs escapes a slice of strings for use in a shell command.
func ShellQuoteArgs(args []string) string {
	quoted := make([]string, len(args))
	for i, arg := range args {
		quoted[i] = ShellQuote(arg)
	}
	return strings.Join(quoted, " ")
}
|
||||
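For reference, here is the escaping rule above in action. This sketch inlines the single-quote wrapping (omitting the safe-character fast path) rather than importing the package:

```go
package main

import (
	"fmt"
	"strings"
)

// quote inlines the core rule from ShellQuote above: wrap in single quotes
// and turn each embedded single quote into '\''.
func quote(s string) string {
	if s == "" {
		return "''"
	}
	return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}

func main() {
	fmt.Println(quote("user's file"))     // 'user'\''s file'
	fmt.Println(quote("$HOME; rm -rf /")) // '$HOME; rm -rf /'
}
```

Single-quote wrapping is robust because POSIX shells give no character special meaning inside single quotes; the only character that needs handling is the single quote itself.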
internal/common/cmdutil/shell_test.go (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
package cmdutil
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShellQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: "''",
|
||||
},
|
||||
{
|
||||
name: "Safe alphanumeric",
|
||||
input: "abcXYZ123",
|
||||
expected: "abcXYZ123",
|
||||
},
|
||||
{
|
||||
name: "Safe special characters",
|
||||
input: "-_./",
|
||||
expected: "-_./",
|
||||
},
|
||||
{
|
||||
name: "String with space",
|
||||
input: "hello world",
|
||||
expected: "'hello world'",
|
||||
},
|
||||
{
|
||||
name: "String with single quote",
|
||||
input: "user's file",
|
||||
expected: "'user'\\''s file'",
|
||||
},
|
||||
{
|
||||
name: "String with multiple single quotes",
|
||||
input: "a'b'c",
|
||||
expected: "'a'\\''b'\\''c'",
|
||||
},
|
||||
{
|
||||
name: "String with double quote",
|
||||
input: `"quoted"`,
|
||||
expected: `'"quoted"'`,
|
||||
},
|
||||
{
|
||||
name: "String with dollar sign",
|
||||
input: "$VAR",
|
||||
expected: "'$VAR'",
|
||||
},
|
||||
{
|
||||
name: "String with asterisk",
|
||||
input: "*.txt",
|
||||
expected: "'*.txt'",
|
||||
},
|
||||
{
|
||||
name: "String with backtick",
|
||||
input: "`date`",
|
||||
expected: "'`date`'",
|
||||
},
|
||||
{
|
||||
name: "String with semicolon",
|
||||
input: "ls; rm -rf /",
|
||||
expected: "'ls; rm -rf /'",
|
||||
},
|
||||
{
|
||||
name: "String with ampersand",
|
||||
input: "run &",
|
||||
expected: "'run &'",
|
||||
},
|
||||
{
|
||||
name: "String with pipe",
|
||||
input: "a | b",
|
||||
expected: "'a | b'",
|
||||
},
|
||||
{
|
||||
name: "String with parentheses",
|
||||
input: "(subshell)",
|
||||
expected: "'(subshell)'",
|
||||
},
|
||||
{
|
||||
name: "String with brackets",
|
||||
input: "[abc]",
|
||||
expected: "'[abc]'",
|
||||
},
|
||||
{
|
||||
name: "String with braces",
|
||||
input: "{1..10}",
|
||||
expected: "'{1..10}'",
|
||||
},
|
||||
{
|
||||
name: "String with redirection",
|
||||
input: "> output.txt",
|
||||
expected: "'> output.txt'",
|
||||
},
|
||||
{
|
||||
name: "String with backslash",
|
||||
input: "path\\to\\file",
|
||||
expected: "'path\\to\\file'",
|
||||
},
|
||||
{
|
||||
name: "String with newline",
|
||||
input: "line1\nline2",
|
||||
expected: "'line1\nline2'",
|
||||
},
|
||||
{
|
||||
name: "String with tab",
|
||||
input: "field1\tfield2",
|
||||
expected: "'field1\tfield2'",
|
||||
},
|
||||
{
|
||||
name: "Unicode string",
|
||||
input: "Hello 世界",
|
||||
expected: "'Hello 世界'",
|
||||
},
|
||||
{
|
||||
name: "Mixed single and double quotes",
|
||||
input: "It's a \"test\"",
|
||||
expected: "'It'\\''s a \"test\"'",
|
||||
},
|
||||
{
|
||||
name: "Only single quotes",
|
||||
input: "'''",
|
||||
expected: "''\\'''\\'''\\'''",
|
||||
},
|
||||
{
|
||||
name: "Backslashes and quotes",
|
||||
input: `\"'\`,
|
||||
expected: `'\"'\''\'`,
|
||||
},
|
||||
{
|
||||
name: "Non-printable characters",
|
||||
input: "\x01\x02\x03",
|
||||
expected: "'\x01\x02\x03'",
|
||||
},
|
||||
{
|
||||
name: "Terminal escape sequence",
|
||||
input: "\x1b[31mRed\x1b[0m",
|
||||
expected: "'\x1b[31mRed\x1b[0m'",
|
||||
},
|
||||
{
|
||||
name: "Ultra nasty mixed string",
|
||||
input: `'"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~",
|
||||
expected: `''\''"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~'",
|
||||
},
|
||||
{
|
||||
name: "Leading/trailing spaces",
|
||||
input: " spaced ",
|
||||
expected: "' spaced '",
|
||||
},
|
||||
{
|
||||
name: "Command injection attempt",
|
||||
input: "; rm -rf / ; #",
|
||||
expected: "'; rm -rf / ; #'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := ShellQuote(tt.input)
|
||||
assert.Equal(t, tt.expected, actual, "Input: %s", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote_ShellRoundTrip(t *testing.T) {
|
||||
if _, err := exec.LookPath("sh"); err != nil {
|
||||
t.Skip("sh not found in PATH")
|
||||
}
|
||||
|
||||
inputs := []string{
|
||||
"simple",
|
||||
"with space",
|
||||
"with 'single' quote",
|
||||
"with \"double\" quote",
|
||||
"with $dollar",
|
||||
"with `backtick`",
|
||||
"with \\backslash",
|
||||
"with \nnewline",
|
||||
"with \ttab",
|
||||
"with 世界 (unicode)",
|
||||
"with mixed \"' $`\\ \n\t chars",
|
||||
"",
|
||||
"'-'",
|
||||
"\"-\"",
|
||||
`'"` + "$; \\ \t\n\r\v\f!#%^&*()[]{}|<>?~", // Ultra nasty
|
||||
" leading and trailing spaces ",
|
||||
"!!!@@@###$$$%%%^^^&&&***((()))_++==--",
|
||||
}
|
||||
|
||||
for _, input := range inputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
quoted := ShellQuote(input)
|
||||
// Run sh -c 'printf %s <quoted>' and capture output
|
||||
// We use printf because echo might interpret sequences or add newlines
|
||||
cmd := exec.Command("sh", "-c", "printf %s "+quoted)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("sh failed for input %q: %v\nOutput: %s", input, err, string(output))
|
||||
}
|
||||
assert.Equal(t, input, string(output), "Round-trip failed for input %q", input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuoteArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "No args",
|
||||
args: []string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single safe arg",
|
||||
args: []string{"ls"},
|
||||
expected: "ls",
|
||||
},
|
||||
{
|
||||
name: "Single unsafe arg",
|
||||
args: []string{"ls -l"},
|
||||
expected: "'ls -l'",
|
||||
},
|
||||
{
|
||||
name: "Multiple args",
|
||||
args: []string{"ls", "-l", "my file.txt"},
|
||||
expected: "ls -l 'my file.txt'",
|
||||
},
|
||||
{
|
||||
name: "Complex args",
|
||||
args: []string{"echo", "It's a beautiful day", "$HOME"},
|
||||
expected: "echo 'It'\\''s a beautiful day' '$HOME'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := ShellQuoteArgs(tt.args)
|
||||
assert.Equal(t, tt.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellQuote_RoundTrip(t *testing.T) {
|
||||
// Exhaustive list of characters to test
|
||||
chars := ""
|
||||
for i := 0; i < 256; i++ {
|
||||
chars += string(rune(i))
|
||||
}
|
||||
|
||||
quoted := ShellQuote(chars)
|
||||
// We don't have a parser here to verify, but we can at least ensure it's not empty and wrapped if needed.
|
||||
assert.NotEmpty(t, quoted)
|
||||
if len(chars) > 0 {
|
||||
assert.True(t, len(quoted) >= len(chars))
|
||||
}
|
||||
}
|
||||
@@ -201,6 +201,7 @@ type PathsConfig struct {
|
||||
ProcDir string
|
||||
ServiceRegistryDir string // Directory for service registry files
|
||||
UsersDir string // Directory for user data (builtin auth)
|
||||
APIKeysDir string // Directory for API key data (builtin auth)
|
||||
ConfigFileUsed string // Path to the configuration file used to load settings
|
||||
}
|
||||
|
||||
|
||||
@@ -337,4 +337,5 @@ func TestConfig_Validate(t *testing.T) {
|
||||
err := cfg.Validate()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
@@ -217,6 +217,7 @@ type PathsDef struct {
|
||||
ProcDir string `mapstructure:"procDir"`
|
||||
ServiceRegistryDir string `mapstructure:"serviceRegistryDir"`
|
||||
UsersDir string `mapstructure:"usersDir"`
|
||||
APIKeysDir string `mapstructure:"apiKeysDir"`
|
||||
}
|
||||
|
||||
// UIDef holds the user interface configuration settings.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/adrg/xdg"
|
||||
"github.com/dagu-org/dagu/internal/common/cmdutil"
|
||||
"github.com/dagu-org/dagu/internal/common/fileutil"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -172,6 +174,12 @@ func (l *ConfigLoader) Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
expandedDef, err := cmdutil.EvalStringFields(context.Background(), def, cmdutil.WithoutSubstitute())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to expand config variables: %w", err)
|
||||
}
|
||||
def = expandedDef
|
||||
|
||||
// Build the final Config from the definition (including legacy fields and validations).
|
||||
cfg, err := l.buildConfig(def)
|
||||
if err != nil {
|
||||
@@ -290,6 +298,7 @@ func (l *ConfigLoader) loadPathsConfig(cfg *Config, def Definition) {
|
||||
cfg.Paths.ProcDir = fileutil.ResolvePathOrBlank(def.Paths.ProcDir)
|
||||
cfg.Paths.ServiceRegistryDir = fileutil.ResolvePathOrBlank(def.Paths.ServiceRegistryDir)
|
||||
cfg.Paths.UsersDir = fileutil.ResolvePathOrBlank(def.Paths.UsersDir)
|
||||
cfg.Paths.APIKeysDir = fileutil.ResolvePathOrBlank(def.Paths.APIKeysDir)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,6 +401,13 @@ func (l *ConfigLoader) loadServerConfig(cfg *Config, def Definition) {
|
||||
l.warnings = append(l.warnings, fmt.Sprintf("Auth mode auto-detected as 'oidc' based on OIDC configuration (issuer: %s)", oidc.Issuer))
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if basic auth is configured with builtin auth mode (it will be ignored)
|
||||
if cfg.Server.Auth.Mode == AuthModeBuiltin {
|
||||
if cfg.Server.Auth.Basic.Username != "" || cfg.Server.Auth.Basic.Password != "" {
|
||||
l.warnings = append(l.warnings, "Basic auth configuration is ignored when auth mode is 'builtin'; use builtin auth's admin credentials instead")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set default token TTL if not specified
|
||||
@@ -613,6 +629,9 @@ func (l *ConfigLoader) finalizePaths(cfg *Config) {
|
||||
if cfg.Paths.UsersDir == "" {
|
||||
cfg.Paths.UsersDir = filepath.Join(cfg.Paths.DataDir, "users")
|
||||
}
|
||||
if cfg.Paths.APIKeysDir == "" {
|
||||
cfg.Paths.APIKeysDir = filepath.Join(cfg.Paths.DataDir, "apikeys")
|
||||
}
|
||||
|
||||
if cfg.Paths.Executable == "" {
|
||||
if executable, err := os.Executable(); err == nil {
|
||||
|
||||
@@ -200,7 +200,8 @@ func TestLoad_Env(t *testing.T) {
|
||||
ProcDir: filepath.Join(testPaths, "proc"),
|
||||
QueueDir: filepath.Join(testPaths, "queue"),
|
||||
ServiceRegistryDir: filepath.Join(testPaths, "service-registry"),
|
||||
UsersDir: filepath.Join(testPaths, "data", "users"), // Derived from DataDir
|
||||
APIKeysDir: filepath.Join(testPaths, "data", "apikeys"), // Derived from DataDir
|
||||
},
|
||||
UI: UI{
|
||||
LogEncodingCharset: "iso-8859-1",
|
||||
@@ -432,6 +433,7 @@ scheduler:
|
||||
QueueDir: "/var/dagu/data/queue",
|
||||
ServiceRegistryDir: "/var/dagu/data/service-registry",
|
||||
UsersDir: "/var/dagu/data/users",
|
||||
APIKeysDir: "/var/dagu/data/apikeys",
|
||||
},
|
||||
UI: UI{
|
||||
LogEncodingCharset: "iso-8859-1",
|
||||
@@ -606,6 +608,26 @@ scheduler:
|
||||
assert.Contains(t, cfg.Warnings[1], "Invalid scheduler.lockRetryInterval")
|
||||
assert.Contains(t, cfg.Warnings[2], "Invalid scheduler.zombieDetectionInterval")
|
||||
})
|
||||
|
||||
t.Run("BuiltinAuthWithBasicAuthWarning", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := loadWithEnv(t, `
|
||||
auth:
|
||||
mode: builtin
|
||||
builtin:
|
||||
admin:
|
||||
username: admin
|
||||
token:
|
||||
secret: test-secret
|
||||
basic:
|
||||
username: basicuser
|
||||
password: basicpass
|
||||
paths:
|
||||
usersDir: /tmp/users
|
||||
`, nil)
|
||||
require.Len(t, cfg.Warnings, 1)
|
||||
assert.Contains(t, cfg.Warnings[0], "Basic auth configuration is ignored when auth mode is 'builtin'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoad_LegacyEnv(t *testing.T) {
|
||||
@@ -731,6 +753,19 @@ func loadWithEnv(t *testing.T, yaml string, env map[string]string) *Config {
|
||||
return loadFromYAML(t, yaml)
|
||||
}
|
||||
|
||||
func unsetEnv(t *testing.T, key string) {
|
||||
t.Helper()
|
||||
original, existed := os.LookupEnv(key)
|
||||
os.Unsetenv(key)
|
||||
t.Cleanup(func() {
|
||||
if existed {
|
||||
os.Setenv(key, original)
|
||||
return
|
||||
}
|
||||
os.Unsetenv(key)
|
||||
})
|
||||
}
|
||||
|
||||
// loadFromYAML loads config from YAML string
|
||||
func loadFromYAML(t *testing.T, yaml string) *Config {
|
||||
t.Helper()
|
||||
@@ -963,3 +998,19 @@ auth:
|
||||
assert.Equal(t, 24*time.Hour, cfg.Server.Auth.Builtin.Token.TTL)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoad_AuthTokenEnvExpansion(t *testing.T) {
|
||||
t.Parallel()
|
||||
unsetEnv(t, "DAGU_AUTH_TOKEN")
|
||||
unsetEnv(t, "DAGU_AUTHTOKEN")
|
||||
|
||||
cfg := loadWithEnv(t, `
|
||||
auth:
|
||||
token:
|
||||
value: "${AUTH_TOKEN}"
|
||||
`, map[string]string{
|
||||
"AUTH_TOKEN": "env-token",
|
||||
})
|
||||
|
||||
assert.Equal(t, "env-token", cfg.Server.Auth.Token.Value)
|
||||
}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
package config
|
||||
|
||||
import "sync"
|
||||
|
||||
// globalViperMu serializes all access to the shared viper instance used across the application.
|
||||
// Uses RWMutex to allow concurrent reads while serializing writes.
|
||||
var globalViperMu sync.RWMutex
|
||||
|
||||
// lockViper acquires the global viper write lock (private).
|
||||
func lockViper() {
|
||||
globalViperMu.Lock()
|
||||
}
|
||||
|
||||
// unlockViper releases the global viper write lock (private).
|
||||
func unlockViper() {
|
||||
globalViperMu.Unlock()
|
||||
}
|
||||
|
||||
// WithViperLock runs the provided function while holding the global viper write lock.
|
||||
// Use this for any operations that modify viper state (Set, BindPFlag, ReadInConfig, etc.).
|
||||
func WithViperLock(fn func()) {
|
||||
lockViper()
|
||||
defer unlockViper()
|
||||
fn()
|
||||
}
|
||||
|
||||
// WithViperRLock runs the provided function while holding the global viper read lock.
|
||||
// Use this for read-only operations (Get, IsSet, etc.) to allow concurrent access.
|
||||
func WithViperRLock(fn func()) {
|
||||
globalViperMu.RLock()
|
||||
defer globalViperMu.RUnlock()
|
||||
fn()
|
||||
}
|
||||
@@ -221,6 +221,8 @@ func TestRegistry_Providers(t *testing.T) {
|
||||
assert.Contains(t, providers, "custom")
|
||||
}
|
||||
|
||||
var _ Resolver = (*mockResolver)(nil)
|
||||
|
||||
// mockResolver is a test double for the Resolver interface
|
||||
type mockResolver struct {
|
||||
mockName string
|
||||
|
||||
@@ -68,6 +68,38 @@ func RemoveQuotes(s string) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// ScreamingSnakeToCamel converts a SCREAMING_SNAKE_CASE string to camelCase.
|
||||
// Example: TOTAL_COUNT -> totalCount, MY_VAR -> myVar, FOO -> foo
|
||||
func ScreamingSnakeToCamel(s string) string {
|
||||
parts := strings.Split(s, "_")
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
isFirst := true
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(part)
|
||||
if isFirst {
|
||||
result.WriteString(lower)
|
||||
isFirst = false
|
||||
} else {
|
||||
// Capitalize first letter of subsequent parts
|
||||
if len(lower) > 0 {
|
||||
result.WriteString(strings.ToUpper(lower[:1]))
|
||||
if len(lower) > 1 {
|
||||
result.WriteString(lower[1:])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// KebabToCamel converts a kebab-case string to camelCase.
|
||||
func KebabToCamel(s string) string {
|
||||
parts := strings.Split(s, "-")
|
||||
|
||||
@@ -141,6 +141,33 @@ func TestKebabToCamel(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestScreamingSnakeToCamel(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"TOTAL_COUNT", "totalCount"},
|
||||
{"MY_VAR", "myVar"},
|
||||
{"FOO", "foo"},
|
||||
{"FOO_BAR_BAZ", "fooBarBaz"},
|
||||
{"", ""},
|
||||
{"_LEADING", "leading"},
|
||||
{"TRAILING_", "trailing"},
|
||||
{"DOUBLE__UNDERSCORE", "doubleUnderscore"},
|
||||
{"already_lower", "alreadyLower"},
|
||||
{"MiXeD_CaSe", "mixedCase"},
|
||||
{"A", "a"},
|
||||
{"A_B_C", "aBC"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := stringutil.ScreamingSnakeToCamel(tt.input)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsMultiLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/dagu-org/dagu/internal/core/spec"
|
||||
)
|
||||
|
||||
var _ execution.DAGStore = (*mockDAGStore)(nil)
|
||||
|
||||
// Mock implementations
|
||||
type mockDAGStore struct {
|
||||
mock.Mock
|
||||
@@ -151,9 +153,12 @@ func (m *mockDAGRunStore) FindSubAttempt(ctx context.Context, dagRun execution.D
|
||||
return args.Get(0).(execution.DAGRunAttempt), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int) error {
|
||||
args := m.Called(ctx, name, retentionDays)
|
||||
return args.Error(0)
|
||||
func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...execution.RemoveOldDAGRunsOption) ([]string, error) {
|
||||
args := m.Called(ctx, name, retentionDays, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]string), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockDAGRunStore) RenameDAGRuns(ctx context.Context, oldName, newName string) error {
|
||||
@@ -218,6 +223,8 @@ func (m *mockQueueStore) All(ctx context.Context) ([]execution.QueuedItemData, e
|
||||
return args.Get(0).([]execution.QueuedItemData), args.Error(1)
|
||||
}
|
||||
|
||||
var _ execution.ServiceRegistry = (*mockServiceRegistry)(nil)
|
||||
|
||||
type mockServiceRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
internal/core/capabilities.go (new file, 83 lines)
@@ -0,0 +1,83 @@
|
||||
package core

// ExecutorCapabilities defines what an executor can do.
type ExecutorCapabilities struct {
	// Command indicates whether the executor supports the command field.
	Command bool
	// MultipleCommands indicates whether the executor supports multiple commands.
	MultipleCommands bool
	// Script indicates whether the executor supports the script field.
	Script bool
	// Shell indicates whether the executor uses shell/shellArgs/shellPackages.
	Shell bool
	// Container indicates whether the executor supports step-level container config.
	Container bool
	// SubDAG indicates whether the executor can execute sub-DAGs.
	SubDAG bool
	// WorkerSelector indicates whether the executor supports worker selection.
	WorkerSelector bool
}

// executorCapabilitiesRegistry is a typed registry of executor capabilities.
type executorCapabilitiesRegistry struct {
	caps map[string]ExecutorCapabilities
}

var executorCapabilities = executorCapabilitiesRegistry{
	caps: make(map[string]ExecutorCapabilities),
}

// Register registers capabilities for an executor type.
func (r *executorCapabilitiesRegistry) Register(executorType string, caps ExecutorCapabilities) {
	r.caps[executorType] = caps
}

// Get returns capabilities for an executor type.
// Returns an empty ExecutorCapabilities if not registered.
func (r *executorCapabilitiesRegistry) Get(executorType string) ExecutorCapabilities {
	if caps, ok := r.caps[executorType]; ok {
		return caps
	}
	// Default: return all false (strict mode)
	return ExecutorCapabilities{}
}

// RegisterExecutorCapabilities registers capabilities for an executor type.
func RegisterExecutorCapabilities(executorType string, caps ExecutorCapabilities) {
	executorCapabilities.Register(executorType, caps)
}

// SupportsCommand returns whether the executor type supports the command field.
func SupportsCommand(executorType string) bool {
	return executorCapabilities.Get(executorType).Command
}

// SupportsMultipleCommands returns whether the executor type supports multiple commands.
func SupportsMultipleCommands(executorType string) bool {
	return executorCapabilities.Get(executorType).MultipleCommands
}

// SupportsScript returns whether the executor type supports the script field.
func SupportsScript(executorType string) bool {
	return executorCapabilities.Get(executorType).Script
}

// SupportsShell returns whether the executor type uses shell configuration.
func SupportsShell(executorType string) bool {
	return executorCapabilities.Get(executorType).Shell
}

// SupportsContainer returns whether the executor type supports step-level container config.
func SupportsContainer(executorType string) bool {
	return executorCapabilities.Get(executorType).Container
}

// SupportsSubDAG returns whether the executor type can execute sub-DAGs.
func SupportsSubDAG(executorType string) bool {
	return executorCapabilities.Get(executorType).SubDAG
}

// SupportsWorkerSelector returns whether the executor type supports worker selection.
func SupportsWorkerSelector(executorType string) bool {
	return executorCapabilities.Get(executorType).WorkerSelector
}
|
||||
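A sketch of how this registry is meant to be used: an executor package registers its capabilities once (typically from init()), and validation code queries the Supports* helpers; unregistered executor types fall back to all-false, the strict default. The registry and executor names below are a self-contained mirror for illustration, not dagu's actual executors:

```go
package main

import "fmt"

// Capabilities mirrors a subset of ExecutorCapabilities from the file above.
type Capabilities struct {
	Command        bool
	WorkerSelector bool
}

var registry = map[string]Capabilities{}

// register would typically run from an executor package's init().
func register(name string, c Capabilities) { registry[name] = c }

// supportsCommand falls back to the zero value (all false) for unknown
// executor types — the same strict default as the real registry.
func supportsCommand(name string) bool { return registry[name].Command }

func main() {
	register("my-executor", Capabilities{Command: true, WorkerSelector: true})
	fmt.Println(supportsCommand("my-executor")) // true
	fmt.Println(supportsCommand("unknown"))     // false (strict default)
}
```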
internal/core/capabilities_test.go (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
package core

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestExecutorCapabilities_Get(t *testing.T) {
	registry := &executorCapabilitiesRegistry{
		caps: make(map[string]ExecutorCapabilities),
	}

	// Test case 1: Registered executor
	caps := ExecutorCapabilities{Command: true, MultipleCommands: true}
	registry.Register("test-executor", caps)
	assert.Equal(t, caps, registry.Get("test-executor"))

	// Test case 2: Unregistered executor should return empty capabilities (strict default)
	assert.Equal(t, ExecutorCapabilities{}, registry.Get("unregistered"))
}

func TestSupportsHelpers(t *testing.T) {
	// Register a test executor with specific capabilities
	caps := ExecutorCapabilities{
		Command:        true,
		Script:         false,
		WorkerSelector: true,
	}
	RegisterExecutorCapabilities("helper-test", caps)

	assert.True(t, SupportsCommand("helper-test"))
	assert.False(t, SupportsScript("helper-test"))
	assert.True(t, SupportsWorkerSelector("helper-test"))

	// Unregistered executor should return false for everything
	assert.False(t, SupportsCommand("unknown"))
	assert.False(t, SupportsScript("unknown"))
	assert.False(t, SupportsShell("unknown"))
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
// Container defines the container configuration for the DAG.
|
||||
type Container struct {
|
||||
// Name is the container name to use. If empty, Docker generates a random name.
|
||||
Name string `yaml:"name,omitempty"`
|
||||
// Image is the container image to use.
|
||||
Image string `yaml:"image,omitempty"`
|
||||
// PullPolicy is the policy to pull the image (e.g., "Always", "IfNotPresent").
|
||||
|
||||
@@ -29,6 +29,36 @@ const (
|
||||
TypeAgent = "agent"
|
||||
)
|
||||
|
||||
// LogOutputMode represents the mode for log output handling.
|
||||
// It determines how stdout and stderr are written to log files.
|
||||
type LogOutputMode string
|
||||
|
||||
const (
|
||||
// LogOutputSeparate keeps stdout and stderr in separate files (.out and .err).
|
||||
// This is the default behavior for backward compatibility.
|
||||
LogOutputSeparate LogOutputMode = "separate"
|
||||
|
||||
// LogOutputMerged combines stdout and stderr into a single log file (.log).
|
||||
// Both streams are interleaved in the order they are written.
|
||||
LogOutputMerged LogOutputMode = "merged"
|
||||
)
|
||||
|
||||
// EffectiveLogOutput returns the effective log output mode for a step.
|
||||
// It resolves the inheritance chain: step-level overrides DAG-level,
|
||||
// and if neither is set, returns the default (LogOutputSeparate).
|
||||
func EffectiveLogOutput(dag *DAG, step *Step) LogOutputMode {
|
||||
// Step-level override takes precedence
|
||||
if step != nil && step.LogOutput != "" {
|
||||
return step.LogOutput
|
||||
}
|
||||
// Fall back to DAG-level setting
|
||||
if dag != nil && dag.LogOutput != "" {
|
||||
return dag.LogOutput
|
||||
}
|
||||
// Default to separate
|
||||
return LogOutputSeparate
|
||||
}
|
||||
|
||||
// DAG contains all information about a DAG.
|
||||
type DAG struct {
|
||||
// WorkingDir is the working directory to run the DAG.
|
||||
@@ -71,6 +101,10 @@ type DAG struct {
|
||||
Env []string `json:"env,omitempty"`
|
||||
// LogDir is the directory where the logs are stored.
|
||||
LogDir string `json:"logDir,omitempty"`
|
||||
// LogOutput specifies how stdout and stderr are handled in log files.
|
||||
// Can be "separate" (default) for separate .out and .err files,
|
||||
// or "merged" for a single combined .log file.
|
||||
LogOutput LogOutputMode `json:"logOutput,omitempty"`
|
||||
// DefaultParams contains the default parameters to be passed to the DAG.
|
||||
DefaultParams string `json:"defaultParams,omitempty"`
|
||||
// Params contains the list of parameters to be passed to the DAG.
|
||||
|
||||
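The resolution order of EffectiveLogOutput above is easiest to see with values plugged in. A minimal standalone sketch of the same precedence (step overrides DAG, and the final default is "separate"):

```go
package main

import "fmt"

type LogOutputMode string

const (
	Separate LogOutputMode = "separate"
	Merged   LogOutputMode = "merged"
)

// effective mirrors EffectiveLogOutput: a non-empty step setting wins,
// then a non-empty DAG setting, and "separate" is the final default.
func effective(dagMode, stepMode LogOutputMode) LogOutputMode {
	if stepMode != "" {
		return stepMode
	}
	if dagMode != "" {
		return dagMode
	}
	return Separate
}

func main() {
	fmt.Println(effective("", ""))           // separate (default)
	fmt.Println(effective(Merged, ""))       // merged (DAG-level)
	fmt.Println(effective(Merged, Separate)) // separate (step overrides)
}
```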
@@ -26,7 +26,8 @@ func TestSockAddr(t *testing.T) {
|
||||
dag := &core.DAG{
|
||||
Location: "testdata/testDagVeryLongNameThatExceedsUnixSocketLengthMaximum-testDagVeryLongNameThatExceedsUnixSocketLengthMaximum.yml",
|
||||
}
|
||||
// 50 is the maximum length of a unix socket address
|
||||
// 50 is an application-imposed limit to keep socket names short and portable
|
||||
// (the system limit UNIX_PATH_MAX is typically 108 bytes on Linux)
|
||||
require.LessOrEqual(t, 50, len(dag.SockAddr("")))
|
||||
require.Equal(
|
||||
t,
|
||||
@@ -305,45 +306,574 @@ func TestNextRun(t *testing.T) {
|
||||
require.Equal(t, expectedNext, nextRun)
|
||||
}
|
||||
|
||||
func TestAuthConfig(t *testing.T) {
|
||||
t.Run("AuthConfigFields", func(t *testing.T) {
|
||||
auth := &core.AuthConfig{
|
||||
Username: "test-user",
|
||||
Password: "test-pass",
|
||||
Auth: "dGVzdC11c2VyOnRlc3QtcGFzcw==", // base64("test-user:test-pass")
|
||||
}
|
||||
func TestEffectiveLogOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "test-user", auth.Username)
|
||||
assert.Equal(t, "test-pass", auth.Password)
|
||||
assert.Equal(t, "dGVzdC11c2VyOnRlc3QtcGFzcw==", auth.Auth)
|
||||
tests := []struct {
|
||||
name string
|
||||
dagLogOutput core.LogOutputMode
|
||||
stepLogOutput core.LogOutputMode
|
||||
expected core.LogOutputMode
|
||||
}{
|
||||
{
|
||||
name: "BothEmpty_ReturnsSeparate",
|
||||
dagLogOutput: "",
|
||||
stepLogOutput: "",
|
||||
expected: core.LogOutputSeparate,
|
||||
},
|
||||
{
|
||||
name: "DAGSeparate_StepEmpty_ReturnsSeparate",
|
||||
dagLogOutput: core.LogOutputSeparate,
|
||||
stepLogOutput: "",
|
||||
expected: core.LogOutputSeparate,
|
||||
},
|
||||
{
|
||||
name: "DAGMerged_StepEmpty_ReturnsMerged",
|
||||
dagLogOutput: core.LogOutputMerged,
|
||||
stepLogOutput: "",
|
||||
expected: core.LogOutputMerged,
|
||||
},
|
||||
{
|
||||
name: "DAGEmpty_StepSeparate_ReturnsSeparate",
|
||||
dagLogOutput: "",
|
||||
stepLogOutput: core.LogOutputSeparate,
|
||||
expected: core.LogOutputSeparate,
|
||||
},
|
||||
{
|
||||
name: "DAGEmpty_StepMerged_ReturnsMerged",
|
||||
dagLogOutput: "",
|
||||
stepLogOutput: core.LogOutputMerged,
|
||||
expected: core.LogOutputMerged,
|
||||
},
|
||||
{
|
||||
name: "DAGSeparate_StepMerged_StepOverrides",
|
||||
dagLogOutput: core.LogOutputSeparate,
|
||||
stepLogOutput: core.LogOutputMerged,
|
||||
expected: core.LogOutputMerged,
|
||||
},
|
||||
{
|
||||
name: "DAGMerged_StepSeparate_StepOverrides",
|
||||
dagLogOutput: core.LogOutputMerged,
|
||||
stepLogOutput: core.LogOutputSeparate,
|
||||
expected: core.LogOutputSeparate,
|
||||
},
|
||||
{
|
||||
name: "NilDAG_StepMerged_ReturnsMerged",
|
||||
dagLogOutput: "", // Will use nil DAG
|
||||
stepLogOutput: core.LogOutputMerged,
|
||||
expected: core.LogOutputMerged,
|
||||
},
|
||||
{
|
||||
name: "NilStep_DAGMerged_ReturnsMerged",
|
||||
dagLogOutput: core.LogOutputMerged,
|
||||
			stepLogOutput: "", // Will use nil Step
			expected:      core.LogOutputMerged,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var dag *core.DAG
			var step *core.Step

			// Setup DAG
			if tt.name != "NilDAG_StepMerged_ReturnsMerged" {
				dag = &core.DAG{LogOutput: tt.dagLogOutput}
			}

			// Setup Step
			if tt.name != "NilStep_DAGMerged_ReturnsMerged" {
				step = &core.Step{LogOutput: tt.stepLogOutput}
			}

			result := core.EffectiveLogOutput(dag, step)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestDAG_Validate(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		dag     *core.DAG
		wantErr bool
		errMsg  string
	}{
		{
			name: "valid DAG with name passes",
			dag: &core.DAG{
				Name: "test-dag",
				Steps: []core.Step{
					{Name: "step1"},
				},
			},
			wantErr: false,
		},
		{
			name:    "empty name fails",
			dag:     &core.DAG{Name: ""},
			wantErr: true,
			errMsg:  "DAG name is required",
		},
		{
			name: "valid dependencies pass",
			dag: &core.DAG{
				Name: "test-dag",
				Steps: []core.Step{
					{Name: "step1"},
					{Name: "step2", Depends: []string{"step1"}},
				},
			},
			wantErr: false,
		},
		{
			name: "missing dependency fails",
			dag: &core.DAG{
				Name: "test-dag",
				Steps: []core.Step{
					{Name: "step1"},
					{Name: "step2", Depends: []string{"nonexistent"}},
				},
			},
			wantErr: true,
			errMsg:  "non-existent step",
		},
		{
			name: "complex multi-level dependencies pass",
			dag: &core.DAG{
				Name: "test-dag",
				Steps: []core.Step{
					{Name: "step1"},
					{Name: "step2", Depends: []string{"step1"}},
					{Name: "step3", Depends: []string{"step1", "step2"}},
					{Name: "step4", Depends: []string{"step3"}},
				},
			},
			wantErr: false,
		},
		{
			name: "steps with no dependencies pass",
			dag: &core.DAG{
				Name: "test-dag",
				Steps: []core.Step{
					{Name: "step1"},
					{Name: "step2"},
					{Name: "step3"},
				},
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			err := tt.dag.Validate()
			if tt.wantErr {
				require.Error(t, err)
				if tt.errMsg != "" {
					assert.Contains(t, err.Error(), tt.errMsg)
				}
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestDAG_HasTag(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		tags     []string
		search   string
		expected bool
	}{
		{
			name:     "empty tags, search for any returns false",
			tags:     []string{},
			search:   "test",
			expected: false,
		},
		{
			name:     "has tag, search for it returns true",
			tags:     []string{"production", "critical"},
			search:   "production",
			expected: true,
		},
		{
			name:     "has tag, search for different returns false",
			tags:     []string{"production", "critical"},
			search:   "staging",
			expected: false,
		},
		{
			name:     "multiple tags, search for last one returns true",
			tags:     []string{"a", "b", "c", "d"},
			search:   "d",
			expected: true,
		},
		{
			name:     "case sensitive check - exact match",
			tags:     []string{"Production"},
			search:   "Production",
			expected: true,
		},
		{
			name:     "case sensitive check - different case returns false",
			tags:     []string{"Production"},
			search:   "production",
			expected: false,
		},
		{
			name:     "nil tags returns false",
			tags:     nil,
			search:   "test",
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			dag := &core.DAG{Tags: tt.tags}
			result := dag.HasTag(tt.search)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestDAG_ParamsMap(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		params   []string
		expected map[string]string
	}{
		{
			name:     "empty params returns empty map",
			params:   []string{},
			expected: map[string]string{},
		},
		{
			name:     "single param key=value",
			params:   []string{"key=value"},
			expected: map[string]string{"key": "value"},
		},
		{
			name:     "multiple params",
			params:   []string{"key1=value1", "key2=value2", "key3=value3"},
			expected: map[string]string{"key1": "value1", "key2": "value2", "key3": "value3"},
		},
		{
			name:     "param with multiple equals - first splits",
			params:   []string{"key=value=with=equals"},
			expected: map[string]string{"key": "value=with=equals"},
		},
		{
			name:     "param without equals - excluded",
			params:   []string{"noequals"},
			expected: map[string]string{},
		},
		{
			name:     "mixed valid and invalid params",
			params:   []string{"valid=value", "invalid", "another=one"},
			expected: map[string]string{"valid": "value", "another": "one"},
		},
		{
			name:     "empty value",
			params:   []string{"key="},
			expected: map[string]string{"key": ""},
		},
		{
			name:     "nil params returns empty map",
			params:   nil,
			expected: map[string]string{},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			dag := &core.DAG{Params: tt.params}
			result := dag.ParamsMap()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestDAG_ProcGroup(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		queue    string
		dagName  string
		expected string
	}{
		{
			name:     "queue set returns queue",
			queue:    "my-queue",
			dagName:  "my-dag",
			expected: "my-queue",
		},
		{
			name:     "queue empty returns dag name",
			queue:    "",
			dagName:  "my-dag",
			expected: "my-dag",
		},
		{
			name:     "both empty returns empty string",
			queue:    "",
			dagName:  "",
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			dag := &core.DAG{Queue: tt.queue, Name: tt.dagName}
			result := dag.ProcGroup()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestDAG_FileName(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		location string
		expected string
	}{
		{
			name:     "location with .yaml extension",
			location: "/path/to/mydag.yaml",
			expected: "mydag",
		},
		{
			name:     "location with .yml extension",
			location: "/path/to/mydag.yml",
			expected: "mydag",
		},
		{
			name:     "location with no extension",
			location: "/path/to/mydag",
			expected: "mydag",
		},
		{
			name:     "empty location returns empty string",
			location: "",
			expected: "",
		},
		{
			name:     "just filename with yaml",
			location: "simple.yaml",
			expected: "simple",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			dag := &core.DAG{Location: tt.location}
			result := dag.FileName()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestDAG_String(t *testing.T) {
	t.Parallel()

	t.Run("full DAG formatted output", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{
			Name:        "test-dag",
			Description: "A test DAG",
			Params:      []string{"param1=value1", "param2=value2"},
			LogDir:      "/var/log/dags",
			Steps: []core.Step{
				{Name: "step1"},
				{Name: "step2"},
			},
		}
		result := dag.String()

		// Verify key fields are included
		assert.Contains(t, result, "test-dag")
		assert.Contains(t, result, "A test DAG")
		assert.Contains(t, result, "param1=value1")
		assert.Contains(t, result, "/var/log/dags")
	})

	t.Run("minimal DAG basic output", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{
			Name: "minimal",
		}
		result := dag.String()
		assert.Contains(t, result, "minimal")
		assert.Contains(t, result, "{")
		assert.Contains(t, result, "}")
	})
}

func TestDAGRegistryAuths(t *testing.T) {
	t.Run("DAGWithRegistryAuths", func(t *testing.T) {
		dag := &core.DAG{
			Name: "test-dag",
			RegistryAuths: map[string]*core.AuthConfig{
				"docker.io": {
					Username: "docker-user",
					Password: "docker-pass",
				},
				"ghcr.io": {
					Username: "github-user",
					Password: "github-token",
				},
			},
		}

		assert.NotNil(t, dag.RegistryAuths)
		assert.Len(t, dag.RegistryAuths, 2)

		dockerAuth, exists := dag.RegistryAuths["docker.io"]
		assert.True(t, exists)
		assert.Equal(t, "docker-user", dockerAuth.Username)

		ghcrAuth, exists := dag.RegistryAuths["ghcr.io"]
		assert.True(t, exists)
		assert.Equal(t, "github-user", ghcrAuth.Username)
	})
}

func TestDAG_InitializeDefaults(t *testing.T) {
	t.Parallel()

	t.Run("empty DAG sets all defaults", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{}
		core.InitializeDefaults(dag)

		assert.Equal(t, core.TypeChain, dag.Type)
		assert.Equal(t, 30, dag.HistRetentionDays)
		assert.Equal(t, 5*time.Second, dag.MaxCleanUpTime)
		assert.Equal(t, 1, dag.MaxActiveRuns)
		assert.Equal(t, 1024*1024, dag.MaxOutputSize)
	})

	t.Run("pre-existing Type not overwritten", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{Type: core.TypeGraph}
		core.InitializeDefaults(dag)

		assert.Equal(t, core.TypeGraph, dag.Type)
	})

	t.Run("pre-existing HistRetentionDays not overwritten", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{HistRetentionDays: 90}
		core.InitializeDefaults(dag)

		assert.Equal(t, 90, dag.HistRetentionDays)
	})

	t.Run("pre-existing MaxActiveRuns not overwritten", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{MaxActiveRuns: 5}
		core.InitializeDefaults(dag)

		assert.Equal(t, 5, dag.MaxActiveRuns)
	})

	t.Run("negative MaxActiveRuns preserved (queueing disabled)", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{MaxActiveRuns: -1}
		core.InitializeDefaults(dag)

		// Negative values mean queueing is disabled, should be preserved
		assert.Equal(t, -1, dag.MaxActiveRuns)
	})
}

func TestDAG_NextRun_Extended(t *testing.T) {
	t.Parallel()

	t.Run("empty schedule returns zero time", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{Schedule: []core.Schedule{}}
		now := time.Now()
		result := dag.NextRun(now)
		assert.True(t, result.IsZero())
	})

	t.Run("single schedule returns correct next time", func(t *testing.T) {
		t.Parallel()

		// Schedule for every hour at minute 0
		parsed, err := cron.ParseStandard("0 * * * *")
		require.NoError(t, err)

		dag := &core.DAG{
			Schedule: []core.Schedule{
				{Expression: "0 * * * *", Parsed: parsed},
			},
		}

		now := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
		result := dag.NextRun(now)

		// Should be the next hour
		expected := time.Date(2023, 10, 1, 13, 0, 0, 0, time.UTC)
		assert.Equal(t, expected, result)
	})

	t.Run("multiple schedules returns earliest", func(t *testing.T) {
		t.Parallel()

		// First schedule: every hour at minute 0
		hourly, err := cron.ParseStandard("0 * * * *")
		require.NoError(t, err)

		// Second schedule: every 30 minutes
		halfHourly, err := cron.ParseStandard("*/30 * * * *")
		require.NoError(t, err)

		dag := &core.DAG{
			Schedule: []core.Schedule{
				{Expression: "0 * * * *", Parsed: hourly},
				{Expression: "*/30 * * * *", Parsed: halfHourly},
			},
		}

		now := time.Date(2023, 10, 1, 12, 15, 0, 0, time.UTC)
		result := dag.NextRun(now)

		// Should be at 12:30 (every 30 min) before 13:00 (hourly)
		expected := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
		assert.Equal(t, expected, result)
	})

	t.Run("nil Parsed in schedule is skipped", func(t *testing.T) {
		t.Parallel()

		parsed, err := cron.ParseStandard("0 * * * *")
		require.NoError(t, err)

		dag := &core.DAG{
			Schedule: []core.Schedule{
				{Expression: "invalid", Parsed: nil},      // nil Parsed should be skipped
				{Expression: "0 * * * *", Parsed: parsed}, // valid
			},
		}

		now := time.Date(2023, 10, 1, 12, 30, 0, 0, time.UTC)
		result := dag.NextRun(now)

		expected := time.Date(2023, 10, 1, 13, 0, 0, 0, time.UTC)
		assert.Equal(t, expected, result)
	})
}

func TestDAG_GetName(t *testing.T) {
	t.Parallel()

	t.Run("name set returns name", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{Name: "my-dag", Location: "/path/to/other.yaml"}
		assert.Equal(t, "my-dag", dag.GetName())
	})

	t.Run("name empty returns filename from location", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{Name: "", Location: "/path/to/mydag.yaml"}
		assert.Equal(t, "mydag", dag.GetName())
	})

	t.Run("name empty and location empty returns empty", func(t *testing.T) {
		t.Parallel()
		dag := &core.DAG{Name: "", Location: ""}
		assert.Equal(t, "", dag.GetName())
	})
}

internal/core/errors_test.go (new file, 460 lines)
@@ -0,0 +1,460 @@
package core

import (
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestErrorList_Error(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		errList  ErrorList
		expected string
	}{
		{
			name:     "empty list returns empty string",
			errList:  ErrorList{},
			expected: "",
		},
		{
			name:     "single error returns error message",
			errList:  ErrorList{errors.New("first error")},
			expected: "first error",
		},
		{
			name:     "multiple errors joined with semicolon",
			errList:  ErrorList{errors.New("first"), errors.New("second"), errors.New("third")},
			expected: "first; second; third",
		},
		{
			name:     "two errors joined with semicolon",
			errList:  ErrorList{errors.New("error1"), errors.New("error2")},
			expected: "error1; error2",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := tt.errList.Error()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestErrorList_ToStringList(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		errList  ErrorList
		expected []string
	}{
		{
			name:     "empty list returns empty slice",
			errList:  ErrorList{},
			expected: []string{},
		},
		{
			name:     "single error returns slice with one string",
			errList:  ErrorList{errors.New("single error")},
			expected: []string{"single error"},
		},
		{
			name:     "multiple errors returns slice with all strings",
			errList:  ErrorList{errors.New("first"), errors.New("second"), errors.New("third")},
			expected: []string{"first", "second", "third"},
		},
		{
			name:     "preserves order of errors",
			errList:  ErrorList{errors.New("a"), errors.New("b"), errors.New("c")},
			expected: []string{"a", "b", "c"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := tt.errList.ToStringList()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestErrorList_Unwrap(t *testing.T) {
	t.Parallel()

	t.Run("empty list returns nil", func(t *testing.T) {
		t.Parallel()
		errList := ErrorList{}
		result := errList.Unwrap()
		assert.Nil(t, result)
	})

	t.Run("single error returns slice with one error", func(t *testing.T) {
		t.Parallel()
		err := errors.New("test error")
		errList := ErrorList{err}
		result := errList.Unwrap()
		require.Len(t, result, 1)
		assert.Equal(t, err, result[0])
	})

	t.Run("multiple errors returns slice with all errors", func(t *testing.T) {
		t.Parallel()
		err1 := errors.New("error 1")
		err2 := errors.New("error 2")
		err3 := errors.New("error 3")
		errList := ErrorList{err1, err2, err3}
		result := errList.Unwrap()
		require.Len(t, result, 3)
		assert.Equal(t, err1, result[0])
		assert.Equal(t, err2, result[1])
		assert.Equal(t, err3, result[2])
	})

	t.Run("errors.Is works with unwrapped errors", func(t *testing.T) {
		t.Parallel()
		targetErr := ErrStepNameRequired
		errList := ErrorList{errors.New("other error"), targetErr}

		// errors.Is should find the target error in the list
		assert.True(t, errors.Is(errList, targetErr))
	})

	t.Run("errors.Is returns false for non-existent error", func(t *testing.T) {
		t.Parallel()
		errList := ErrorList{errors.New("other error")}

		// errors.Is should not find ErrStepNameRequired
		assert.False(t, errors.Is(errList, ErrStepNameRequired))
	})
}

func TestValidationError_Error(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		field    string
		value    any
		err      error
		expected string
	}{
		{
			name:     "nil value formats without value",
			field:    "testField",
			value:    nil,
			err:      errors.New("test error"),
			expected: "field 'testField': test error",
		},
		{
			name:     "string value formats with value",
			field:    "name",
			value:    "my-dag",
			err:      errors.New("name too long"),
			expected: "field 'name': name too long (value: my-dag)",
		},
		{
			name:     "int value formats with value",
			field:    "maxRetries",
			value:    5,
			err:      errors.New("value out of range"),
			expected: "field 'maxRetries': value out of range (value: 5)",
		},
		{
			name:     "empty field name",
			field:    "",
			value:    "test",
			err:      errors.New("invalid"),
			expected: "field '': invalid (value: test)",
		},
		{
			name:     "struct value uses %+v format",
			field:    "config",
			value:    struct{ Name string }{Name: "test"},
			err:      errors.New("invalid config"),
			expected: "field 'config': invalid config (value: {Name:test})",
		},
		{
			name:     "slice value",
			field:    "tags",
			value:    []string{"a", "b"},
			err:      errors.New("invalid tags"),
			expected: "field 'tags': invalid tags (value: [a b])",
		},
		{
			name:     "bool value",
			field:    "enabled",
			value:    true,
			err:      errors.New("cannot enable"),
			expected: "field 'enabled': cannot enable (value: true)",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ve := &ValidationError{
				Field: tt.field,
				Value: tt.value,
				Err:   tt.err,
			}
			result := ve.Error()
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestValidationError_Unwrap(t *testing.T) {
	t.Parallel()

	t.Run("returns underlying error", func(t *testing.T) {
		t.Parallel()
		underlyingErr := errors.New("underlying error")
		ve := &ValidationError{
			Field: "test",
			Value: nil,
			Err:   underlyingErr,
		}
		result := ve.Unwrap()
		assert.Equal(t, underlyingErr, result)
	})

	t.Run("works with errors.Is", func(t *testing.T) {
		t.Parallel()
		ve := &ValidationError{
			Field: "name",
			Value: "test",
			Err:   ErrStepNameTooLong,
		}
		assert.True(t, errors.Is(ve, ErrStepNameTooLong))
	})

	t.Run("works with errors.As", func(t *testing.T) {
		t.Parallel()
		ve := &ValidationError{
			Field: "test",
			Value: "value",
			Err:   errors.New("test"),
		}

		var targetErr *ValidationError
		assert.True(t, errors.As(ve, &targetErr))
		assert.Equal(t, "test", targetErr.Field)
	})

	t.Run("nil underlying error", func(t *testing.T) {
		t.Parallel()
		ve := &ValidationError{
			Field: "test",
			Value: nil,
			Err:   nil,
		}
		result := ve.Unwrap()
		assert.Nil(t, result)
	})
}

func TestNewValidationError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name          string
		field         string
		value         any
		err           error
		expectedField string
		expectedValue any
	}{
		{
			name:          "creates validation error with all fields",
			field:         "steps",
			value:         "step1",
			err:           ErrStepNameDuplicate,
			expectedField: "steps",
			expectedValue: "step1",
		},
		{
			name:          "creates validation error with nil value",
			field:         "name",
			value:         nil,
			err:           ErrStepNameRequired,
			expectedField: "name",
			expectedValue: nil,
		},
		{
			name:          "creates validation error with empty field",
			field:         "",
			value:         123,
			err:           errors.New("test"),
			expectedField: "",
			expectedValue: 123,
		},
		{
			name:          "creates validation error with nil error",
			field:         "test",
			value:         "value",
			err:           nil,
			expectedField: "test",
			expectedValue: "value",
		},
		{
			name:          "creates validation error with complex value",
			field:         "config",
			value:         map[string]int{"a": 1, "b": 2},
			err:           errors.New("invalid"),
			expectedField: "config",
			expectedValue: map[string]int{"a": 1, "b": 2},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := NewValidationError(tt.field, tt.value, tt.err)

			// Assert it returns a ValidationError
			var ve *ValidationError
			require.True(t, errors.As(err, &ve))

			assert.Equal(t, tt.expectedField, ve.Field)
			assert.Equal(t, tt.expectedValue, ve.Value)
			assert.Equal(t, tt.err, ve.Err)
		})
	}
}

func TestErrorConstants(t *testing.T) {
	t.Parallel()

	// Test that all error constants are defined and have meaningful messages
	errorConstants := []struct {
		name string
		err  error
	}{
		{"ErrNameTooLong", ErrNameTooLong},
		{"ErrNameInvalidChars", ErrNameInvalidChars},
		{"ErrInvalidSchedule", ErrInvalidSchedule},
		{"ErrScheduleMustBeStringOrArray", ErrScheduleMustBeStringOrArray},
		{"ErrInvalidScheduleType", ErrInvalidScheduleType},
		{"ErrInvalidKeyType", ErrInvalidKeyType},
		{"ErrExecutorConfigMustBeString", ErrExecutorConfigMustBeString},
		{"ErrDuplicateFunction", ErrDuplicateFunction},
		{"ErrFuncParamsMismatch", ErrFuncParamsMismatch},
		{"ErrInvalidStepData", ErrInvalidStepData},
		{"ErrStepNameRequired", ErrStepNameRequired},
		{"ErrStepNameDuplicate", ErrStepNameDuplicate},
		{"ErrStepNameTooLong", ErrStepNameTooLong},
		{"ErrStepCommandIsRequired", ErrStepCommandIsRequired},
		{"ErrStepCommandIsEmpty", ErrStepCommandIsEmpty},
		{"ErrStepCommandMustBeArrayOrString", ErrStepCommandMustBeArrayOrString},
		{"ErrInvalidParamValue", ErrInvalidParamValue},
		{"ErrCallFunctionNotFound", ErrCallFunctionNotFound},
		{"ErrNumberOfParamsMismatch", ErrNumberOfParamsMismatch},
		{"ErrRequiredParameterNotFound", ErrRequiredParameterNotFound},
		{"ErrScheduleKeyMustBeString", ErrScheduleKeyMustBeString},
		{"ErrInvalidSignal", ErrInvalidSignal},
		{"ErrInvalidEnvValue", ErrInvalidEnvValue},
		{"ErrArgsMustBeConvertibleToIntOrString", ErrArgsMustBeConvertibleToIntOrString},
		{"ErrExecutorTypeMustBeString", ErrExecutorTypeMustBeString},
		{"ErrExecutorConfigValueMustBeMap", ErrExecutorConfigValueMustBeMap},
		{"ErrExecutorHasInvalidKey", ErrExecutorHasInvalidKey},
		{"ErrExecutorConfigMustBeStringOrMap", ErrExecutorConfigMustBeStringOrMap},
		{"ErrDotEnvMustBeStringOrArray", ErrDotEnvMustBeStringOrArray},
		{"ErrPreconditionMustBeArrayOrString", ErrPreconditionMustBeArrayOrString},
		{"ErrPreconditionValueMustBeString", ErrPreconditionValueMustBeString},
		{"ErrPreconditionHasInvalidKey", ErrPreconditionHasInvalidKey},
		{"ErrContinueOnOutputMustBeStringOrArray", ErrContinueOnOutputMustBeStringOrArray},
		{"ErrContinueOnExitCodeMustBeIntOrArray", ErrContinueOnExitCodeMustBeIntOrArray},
		{"ErrDependsMustBeStringOrArray", ErrDependsMustBeStringOrArray},
		{"ErrStepsMustBeArrayOrMap", ErrStepsMustBeArrayOrMap},
	}

	for _, ec := range errorConstants {
		t.Run(ec.name, func(t *testing.T) {
			t.Parallel()

			// Error should not be nil
			require.NotNil(t, ec.err, "error constant %s should not be nil", ec.name)

			// Error message should not be empty
			msg := ec.err.Error()
			assert.NotEmpty(t, msg, "error constant %s should have a non-empty message", ec.name)

			// Error message should be meaningful (at least 5 characters)
			assert.GreaterOrEqual(t, len(msg), 5,
				"error constant %s message should be meaningful (at least 5 chars): %q", ec.name, msg)
		})
	}
}

func TestErrorList_ImplementsErrorInterface(t *testing.T) {
	t.Parallel()

	// Ensure ErrorList implements the error interface
	var _ error = ErrorList{}
	var _ error = &ErrorList{}

	// Test that it can be used as an error
	errList := ErrorList{errors.New("test")}
	var err error = errList
	assert.NotNil(t, err)
	assert.Equal(t, "test", err.Error())
}

func TestValidationError_ImplementsErrorInterface(t *testing.T) {
	t.Parallel()

	// Ensure ValidationError implements the error interface
	var _ error = &ValidationError{}

	// Test that it can be used as an error
	ve := &ValidationError{Field: "test", Err: errors.New("error")}
	var err error = ve
	assert.NotNil(t, err)
	assert.Contains(t, err.Error(), "test")
}

func TestErrorList_WithWrappedErrors(t *testing.T) {
	t.Parallel()

	t.Run("contains wrapped validation error", func(t *testing.T) {
		t.Parallel()

		ve := NewValidationError("field", "value", ErrStepNameRequired)
		errList := ErrorList{ve}

		// Should be able to find the wrapped error
		assert.True(t, errors.Is(errList, ErrStepNameRequired))
	})

	t.Run("contains multiple wrapped errors", func(t *testing.T) {
		t.Parallel()

		ve1 := NewValidationError("field1", nil, ErrStepNameRequired)
		ve2 := NewValidationError("field2", nil, ErrStepNameTooLong)
		errList := ErrorList{ve1, ve2}

		// Should find both wrapped errors
		assert.True(t, errors.Is(errList, ErrStepNameRequired))
		assert.True(t, errors.Is(errList, ErrStepNameTooLong))
	})

	t.Run("fmt.Errorf wrapped errors work", func(t *testing.T) {
		t.Parallel()

		wrapped := fmt.Errorf("context: %w", ErrInvalidSchedule)
		errList := ErrorList{wrapped}

		// Should find the wrapped error
		assert.True(t, errors.Is(errList, ErrInvalidSchedule))
	})
}

@@ -33,11 +33,12 @@ type DAGRunStore interface {
 	FindAttempt(ctx context.Context, dagRun DAGRunRef) (DAGRunAttempt, error)
 	// FindSubAttempt finds a sub dag-run record by dag-run ID.
 	FindSubAttempt(ctx context.Context, dagRun DAGRunRef, subDAGRunID string) (DAGRunAttempt, error)
-	// RemoveOldDAGRuns delete dag-run records older than retentionDays
+	// RemoveOldDAGRuns deletes dag-run records older than retentionDays.
 	// If retentionDays is negative, it won't delete any records.
 	// If retentionDays is zero, it will delete all records for the DAG name.
 	// But it will not delete the records with non-final statuses (e.g., running, queued).
-	RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int) error
+	// Returns a list of dag-run IDs that were removed (or would be removed in dry-run mode).
+	RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...RemoveOldDAGRunsOption) ([]string, error)
 	// RenameDAGRuns renames all run data from oldName to newName
 	// The name means the DAG name, renaming it will allow user to manage those runs
 	// with the new DAG name.
@@ -102,6 +103,22 @@ func WithDAGRunID(dagRunID string) ListDAGRunStatusesOption {
 	}
 }
 
+// RemoveOldDAGRunsOptions contains options for removing old dag-runs
+type RemoveOldDAGRunsOptions struct {
+	// DryRun if true, only returns the paths that would be removed without actually deleting
+	DryRun bool
+}
+
+// RemoveOldDAGRunsOption is a functional option for configuring RemoveOldDAGRunsOptions
+type RemoveOldDAGRunsOption func(*RemoveOldDAGRunsOptions)
+
+// WithDryRun sets the dry-run mode for removing old dag-runs
+func WithDryRun() RemoveOldDAGRunsOption {
+	return func(o *RemoveOldDAGRunsOptions) {
+		o.DryRun = true
+	}
+}
+
 // NewDAGRunAttemptOptions contains options for creating a new run record
 type NewDAGRunAttemptOptions struct {
 	// RootDAGRun is the root dag-run reference for this attempt.
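The new functional option lets callers preview a cleanup before committing to it. A minimal sketch of a call site (the `store` variable and DAG name are hypothetical; the signatures are the ones shown above):

```go
// Preview which dag-runs a cleanup would delete, without deleting anything.
ids, err := store.RemoveOldDAGRuns(ctx, "my-dag", 30, execution.WithDryRun())
if err != nil {
	return err
}
log.Printf("would remove %d dag-runs: %v", len(ids), ids)

// Then perform the actual deletion with the same retention window.
ids, err = store.RemoveOldDAGRuns(ctx, "my-dag", 30)
```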
@@ -133,6 +150,12 @@ type DAGRunAttempt interface {
 	Hide(ctx context.Context) error
 	// Hidden returns true if the attempt is hidden from normal operations.
 	Hidden() bool
+	// WriteOutputs writes the collected step outputs for the dag-run.
+	// Does nothing if outputs is nil or has no output entries.
+	WriteOutputs(ctx context.Context, outputs *DAGRunOutputs) error
+	// ReadOutputs reads the collected step outputs for the dag-run.
+	// Returns nil if no outputs file exists or if the file is in v1 format.
+	ReadOutputs(ctx context.Context) (*DAGRunOutputs, error)
 }
 
 // Errors for RunRef parsing
@@ -67,9 +67,12 @@ func (m *mockDAGRunStore) FindSubAttempt(ctx context.Context, dagRun execution.D
 	return args.Get(0).(execution.DAGRunAttempt), args.Error(1)
 }
 
-func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int) error {
-	args := m.Called(ctx, name, retentionDays)
-	return args.Error(0)
+func (m *mockDAGRunStore) RemoveOldDAGRuns(ctx context.Context, name string, retentionDays int, opts ...execution.RemoveOldDAGRunsOption) ([]string, error) {
+	args := m.Called(ctx, name, retentionDays, opts)
+	if args.Get(0) == nil {
+		return nil, args.Error(1)
+	}
+	return args.Get(0).([]string), args.Error(1)
 }
 
 func (m *mockDAGRunStore) RenameDAGRuns(ctx context.Context, oldName, newName string) error {
@@ -145,6 +148,19 @@ func (m *mockDAGRunAttempt) Hidden() bool {
 	return args.Bool(0)
 }
 
+func (m *mockDAGRunAttempt) WriteOutputs(ctx context.Context, outputs *execution.DAGRunOutputs) error {
+	args := m.Called(ctx, outputs)
+	return args.Error(0)
+}
+
+func (m *mockDAGRunAttempt) ReadOutputs(ctx context.Context) (*execution.DAGRunOutputs, error) {
+	args := m.Called(ctx)
+	if args.Get(0) == nil {
+		return nil, args.Error(1)
+	}
+	return args.Get(0).(*execution.DAGRunOutputs), args.Error(1)
+}
+
 // Tests
 
 func TestListDAGRunStatusesOptions(t *testing.T) {
@@ -242,10 +258,11 @@ func TestDAGRunStoreInterface(t *testing.T) {
 	assert.Equal(t, mockAttempt, childFound)
 
 	// Test RemoveOldDAGRuns
-	store.On("RemoveOldDAGRuns", ctx, "test-dag", 30).Return(nil)
+	store.On("RemoveOldDAGRuns", ctx, "test-dag", 30, mock.Anything).Return([]string{"run-1", "run-2"}, nil)
 
-	err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
+	removedIDs, err := store.RemoveOldDAGRuns(ctx, "test-dag", 30)
 	assert.NoError(t, err)
+	assert.Equal(t, []string{"run-1", "run-2"}, removedIDs)
 
 	// Test RenameDAGRuns
 	store.On("RenameDAGRuns", ctx, "old-name", "new-name").Return(nil)
@@ -383,19 +400,22 @@ func TestRemoveOldDAGRunsEdgeCases(t *testing.T) {
 	store := &mockDAGRunStore{}
 
 	// Test with negative retention days (should not delete anything)
-	store.On("RemoveOldDAGRuns", ctx, "test-dag", -1).Return(nil)
-	err := store.RemoveOldDAGRuns(ctx, "test-dag", -1)
+	store.On("RemoveOldDAGRuns", ctx, "test-dag", -1, mock.Anything).Return([]string(nil), nil)
+	removedIDs, err := store.RemoveOldDAGRuns(ctx, "test-dag", -1)
 	assert.NoError(t, err)
+	assert.Nil(t, removedIDs)
 
 	// Test with zero retention days (should delete all except non-final statuses)
-	store.On("RemoveOldDAGRuns", ctx, "test-dag", 0).Return(nil)
-	err = store.RemoveOldDAGRuns(ctx, "test-dag", 0)
+	store.On("RemoveOldDAGRuns", ctx, "test-dag", 0, mock.Anything).Return([]string{"run-1", "run-2"}, nil)
+	removedIDs, err = store.RemoveOldDAGRuns(ctx, "test-dag", 0)
 	assert.NoError(t, err)
+	assert.Equal(t, []string{"run-1", "run-2"}, removedIDs)
 
 	// Test with positive retention days
-	store.On("RemoveOldDAGRuns", ctx, "test-dag", 30).Return(nil)
-	err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
+	store.On("RemoveOldDAGRuns", ctx, "test-dag", 30, mock.Anything).Return([]string{"run-old"}, nil)
+	removedIDs, err = store.RemoveOldDAGRuns(ctx, "test-dag", 30)
 	assert.NoError(t, err)
+	assert.Equal(t, []string{"run-old"}, removedIDs)
 
 	store.AssertExpectations(t)
 }
internal/core/execution/outputs.go (new file, 17 lines)
@@ -0,0 +1,17 @@
package execution

// DAGRunOutputs represents the full outputs file structure with metadata.
type DAGRunOutputs struct {
	Metadata OutputsMetadata   `json:"metadata"`
	Outputs  map[string]string `json:"outputs"`
}

// OutputsMetadata contains execution context for the outputs.
type OutputsMetadata struct {
	DAGName     string `json:"dagName"`
	DAGRunID    string `json:"dagRunId"`
	AttemptID   string `json:"attemptId"`
	Status      string `json:"status"`
	CompletedAt string `json:"completedAt"`
	Params      string `json:"params,omitempty"` // JSON-serialized parameters
}
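For orientation, a minimal sketch of what this schema serializes to. The field values below are invented, and the snippet assumes the internal package layout shown above:

```go
outputs := execution.DAGRunOutputs{
	Metadata: execution.OutputsMetadata{
		DAGName:     "etl-daily",              // hypothetical DAG name
		DAGRunID:    "20231001-120000-abcd",   // hypothetical run ID
		AttemptID:   "att-1",
		Status:      "succeeded",
		CompletedAt: "2023-10-01T13:00:05Z",
	},
	Outputs: map[string]string{"ROW_COUNT": "42"},
}
data, _ := json.MarshalIndent(outputs, "", "  ")
fmt.Println(string(data))
// {"metadata":{"dagName":"etl-daily",...},"outputs":{"ROW_COUNT":"42"}}
```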
@@ -51,6 +51,7 @@ type DAGRunStatus struct {
 	StartedAt     string            `json:"startedAt,omitempty"`
 	FinishedAt    string            `json:"finishedAt,omitempty"`
 	Log           string            `json:"log,omitempty"`
+	Error         string            `json:"error,omitempty"`
 	Params        string            `json:"params,omitempty"`
 	ParamsList    []string          `json:"paramsList,omitempty"`
 	Preconditions []*core.Condition `json:"preconditions,omitempty"`
@@ -64,6 +65,9 @@ func (st *DAGRunStatus) DAGRun() DAGRunRef {
 // Errors returns a slice of errors for the current status
 func (st *DAGRunStatus) Errors() []error {
 	var errs []error
+	if st.Error != "" {
+		errs = append(errs, fmt.Errorf("%s", st.Error))
+	}
 	for _, node := range st.Nodes {
 		if node.Error != "" {
 			errs = append(errs, fmt.Errorf("node %s: %s", node.Step.Name, node.Error))
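A sketch of how the extended `Errors()` surfaces both the run-level error and per-node errors. The node literal below is illustrative; only the fields the method actually dereferences (`Step.Name`, `Error`) are assumed:

```go
st := &DAGRunStatus{
	Error: "dag-run failed", // run-level error, now reported first
	Nodes: []*Node{
		{Step: core.Step{Name: "step1"}, Error: "exit status 1"},
	},
}
for _, err := range st.Errors() {
	fmt.Println(err)
	// dag-run failed
	// node step1: exit status 1
}
```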
@@ -1,25 +0,0 @@
package core

import "regexp"

// DAGNameMaxLen defines the maximum allowed length for a DAG name.
const DAGNameMaxLen = 40

// dagNameRegex matches valid DAG names: alphanumeric, underscore, dash, dot.
var dagNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)

// ValidateDAGName validates a DAG name according to shared rules.
// - Empty name is allowed (caller may provide one via context or filename).
// - Non-empty name must satisfy length and allowed character constraints.
func ValidateDAGName(name string) error {
	if name == "" {
		return nil
	}
	if len(name) > DAGNameMaxLen {
		return ErrNameTooLong
	}
	if !dagNameRegex.MatchString(name) {
		return ErrNameInvalidChars
	}
	return nil
}
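Since this shared validator is deleted in this range, here is a brief sketch of the rules it enforced, for reference (the inputs are illustrative):

```go
// Behavior of the removed ValidateDAGName, per the code above:
_ = ValidateDAGName("")               // nil: empty names were left to the caller
_ = ValidateDAGName("etl_daily-v1.2") // nil: [a-zA-Z0-9_.-], max 40 chars
_ = ValidateDAGName("bad name!")      // ErrNameInvalidChars (space and '!')
_ = ValidateDAGName(strings.Repeat("x", 41)) // ErrNameTooLong
```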
@@ -118,7 +118,7 @@ steps:
     parallel: ${ITEMS}
 `,
 			wantErr:    true,
-			wantErrMsg: "parallel execution is only supported for child-DAGs",
+			wantErrMsg: "cannot use sub-DAG field",
 		},
 		{
 			name: "ErrorParallelWithoutCommandOrRun",
File diff suppressed because it is too large
@@ -1,69 +0,0 @@
package spec_test

import (
	"context"
	"testing"
	"time"

	"github.com/dagu-org/dagu/internal/core/spec"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStepTimeout(t *testing.T) {
	t.Parallel()

	// Positive timeout
	t.Run("PositiveTimeout", func(t *testing.T) {
		data := []byte(`
steps:
  - name: work
    command: echo doing
    timeoutSec: 5
`)
		dag, err := spec.LoadYAML(context.Background(), data)
		require.NoError(t, err)
		require.Len(t, dag.Steps, 1)
		assert.Equal(t, 5*time.Second, dag.Steps[0].Timeout)
	})

	// Zero timeout (explicit) -> unset/zero duration
	t.Run("ZeroTimeoutExplicit", func(t *testing.T) {
		data := []byte(`
steps:
  - name: work
    command: echo none
    timeoutSec: 0
`)
		dag, err := spec.LoadYAML(context.Background(), data)
		require.NoError(t, err)
		require.Len(t, dag.Steps, 1)
		assert.Equal(t, time.Duration(0), dag.Steps[0].Timeout)
	})

	// Zero timeout (omitted) -> also zero
	t.Run("ZeroTimeoutOmitted", func(t *testing.T) {
		data := []byte(`
steps:
  - name: work
    command: echo omitted
`)
		dag, err := spec.LoadYAML(context.Background(), data)
		require.NoError(t, err)
		require.Len(t, dag.Steps, 1)
		assert.Equal(t, time.Duration(0), dag.Steps[0].Timeout)
	})

	// Negative timeout should fail validation
	t.Run("NegativeTimeout", func(t *testing.T) {
		data := []byte(`
steps:
  - name: bad
    command: echo fail
    timeoutSec: -3
`)
		_, err := spec.LoadYAML(context.Background(), data)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "timeoutSec must be >= 0")
	})
}
@@ -1,106 +0,0 @@
package spec

import (
	"fmt"
	"strings"

	"github.com/dagu-org/dagu/internal/common/cmdutil"
	"github.com/dagu-org/dagu/internal/core"
)

// buildCommand parses the command field in the step definition.
// Case 1: command is nil
// Case 2: command is a string
// Case 3: command is an array
//
// In case 3, the first element is the command and the rest are the arguments.
// If the arguments are not strings, they are converted to strings.
//
// Example:
// ```yaml
// step:
//   - name: "echo hello"
//     command: "echo hello"
// ```
// or
// ```yaml
// step:
//   - name: "echo hello"
//     command: ["echo", "hello"]
// ```
// It returns an error if the command is not nil but empty.
func buildCommand(_ StepBuildContext, def stepDef, step *core.Step) error {
	command := def.Command

	// Case 1: command is nil
	if command == nil {
		return nil
	}

	switch val := command.(type) {
	case string:
		// Case 2: command is a string
		val = strings.TrimSpace(val)
		if val == "" {
			return core.NewValidationError("command", val, ErrStepCommandIsEmpty)
		}

		// If the value is multi-line, treat it as a script
		if strings.Contains(val, "\n") {
			step.Script = val
			return nil
		}

		// We need to split the command into command and args.
		step.CmdWithArgs = val
		cmd, args, err := cmdutil.SplitCommand(val)
		if err != nil {
			return core.NewValidationError("command", val, fmt.Errorf("failed to parse command: %w", err))
		}
		step.Command = strings.TrimSpace(cmd)
		step.Args = args

	case []any:
		// Case 3: command is an array

		var command string
		var args []string
		for _, v := range val {
			val, ok := v.(string)
			if !ok {
				// If the value is not a string, convert it to a string.
				// This is useful when the value is an integer for example.
				val = fmt.Sprintf("%v", v)
			}
			val = strings.TrimSpace(val)
			if command == "" {
				command = val
				continue
			}
			args = append(args, val)
		}

		// Setup CmdWithArgs (this will be actually used in the command execution)
		var sb strings.Builder
		for i, arg := range step.Args {
			if i > 0 {
				sb.WriteString(" ")
			}
			sb.WriteString(fmt.Sprintf("%q", arg))
		}

		step.Command = command
		step.Args = args
		step.CmdWithArgs = fmt.Sprintf("%s %s", step.Command, sb.String())
		step.CmdArgsSys = cmdutil.JoinCommandArgs(step.Command, step.Args)

	default:
		// Unknown type for command field.
		return core.NewValidationError("command", val, ErrStepCommandMustBeArrayOrString)

	}

	return nil
}
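As a reading aid for the removed builder, a sketch of the step fields each YAML form produced. The values are illustrative, and note the real string case used the shell-aware `cmdutil.SplitCommand`, not naive whitespace splitting:

```go
// command: "echo hello"    -> Command: "echo", Args: ["hello"], CmdWithArgs: "echo hello"
// command: |               -> multi-line strings were stored in step.Script instead
//   echo a
//   echo b
// command: ["echo", 42]    -> Command: "echo", Args: ["42"] (non-strings stringified)
var step core.Step
step.Command = "echo"
step.Args = []string{"hello"}
step.CmdWithArgs = "echo hello"
_ = step
```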
internal/core/spec/dag.go (new file, 1173 lines)
File diff suppressed because it is too large
internal/core/spec/dag_test.go (new file, 1833 lines)
File diff suppressed because it is too large
@@ -1,288 +0,0 @@
package spec

// definition is a temporary struct to hold the DAG definition.
// This struct is used to unmarshal the YAML data.
// The data is then converted to the DAG struct.
type definition struct {
	// Name is the name of the DAG.
	Name string
	// Group is the group of the DAG for grouping DAGs on the UI.
	Group string
	// Description is the description of the DAG.
	Description string
	// Type is the execution type for steps (graph, chain, or agent).
	// Default is "graph" which uses dependency-based execution.
	// "chain" executes steps in the order they are defined.
	// "agent" is reserved for future agent-based execution.
	Type string
	// Shell is the default shell to use for all steps in this DAG.
	// If not specified, the system default shell is used.
	// Can be overridden at the step level.
	// Can be a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
	Shell any
	// WorkingDir is working directory for DAG execution
	WorkingDir string
	// Dotenv is the path to the dotenv file (string or []string).
	Dotenv any
	// Schedule is the cron schedule to run the DAG.
	Schedule any
	// SkipIfSuccessful is the flag to skip the DAG on schedule when it is
	// executed manually before the schedule.
	SkipIfSuccessful bool
	// LogFile is the file to write the log.
	LogDir string
	// Env is the environment variables setting.
	Env any
	// HandlerOn is the handler configuration.
	HandlerOn handlerOnDef
	// Steps is the list of steps to run.
	Steps any // []stepDef or map[string]stepDef
	// SMTP is the SMTP configuration.
	SMTP smtpConfigDef
	// MailOn is the mail configuration.
	MailOn *mailOnDef
	// ErrorMail is the mail configuration for error.
	ErrorMail mailConfigDef
	// InfoMail is the mail configuration for information.
	InfoMail mailConfigDef
	// TimeoutSec is the timeout in seconds to finish the DAG.
	TimeoutSec int
	// DelaySec is the delay in seconds to start the first node.
	DelaySec int
	// RestartWaitSec is the wait in seconds to when the DAG is restarted.
	RestartWaitSec int
	// HistRetentionDays is the retention days of the dag-runs history.
	HistRetentionDays *int
	// Precondition is the condition to run the DAG.
	Precondition any
	// Preconditions is the condition to run the DAG.
	Preconditions any
	// maxActiveRuns is the maximum number of concurrent dag-runs.
	MaxActiveRuns int
	// MaxActiveSteps is the maximum number of concurrent steps.
	MaxActiveSteps int
	// Params is the default parameters for the steps.
	Params any
	// MaxCleanUpTimeSec is the maximum time in seconds to clean up the DAG.
	// It is a wait time to kill the processes when it is requested to stop.
	// If the time is exceeded, the process is killed.
	MaxCleanUpTimeSec *int
	// Tags is the tags for the DAG.
	Tags any
	// Queue is the name of the queue to assign this DAG to.
	Queue string
	// MaxOutputSize is the maximum size of the output for each step.
	MaxOutputSize int
	// OTel is the OpenTelemetry configuration.
	OTel any
	// WorkerSelector specifies required worker labels for execution.
	WorkerSelector map[string]string
	// Container is the container definition for the DAG.
	Container *containerDef
	// RunConfig contains configuration for controlling user interactions during DAG runs.
	RunConfig *runConfigDef
	// RegistryAuths maps registry hostnames to authentication configs.
	// Can be either a JSON string or a map of registry to auth config.
	RegistryAuths any
	// SSH is the default SSH configuration for the DAG.
	SSH *sshDef
	// Secrets contains references to external secrets.
	Secrets []secretRefDef
}

// handlerOnDef defines the steps to be executed on different events.
type handlerOnDef struct {
	Init    *stepDef // Step to execute before steps (after preconditions pass)
	Failure *stepDef // Step to execute on failure
	Success *stepDef // Step to execute on success
	Abort   *stepDef // Step to execute on abort (canonical field)
	Cancel  *stepDef // Step to execute on cancel (deprecated: use Abort instead)
	Exit    *stepDef // Step to execute on exit
}

// stepDef defines a step in the DAG.
type stepDef struct {
	// Name is the name of the step.
	Name string `yaml:"name,omitempty"`
	// ID is the optional unique identifier for the step.
	ID string `yaml:"id,omitempty"`
	// Description is the description of the step.
	Description string `yaml:"description,omitempty"`
	// WorkingDir is the working directory of the step.
	WorkingDir string `yaml:"workingDir,omitempty"`
	// Dir is the working directory of the step.
	// Deprecated: use WorkingDir instead
	Dir string `yaml:"dir,omitempty"`
	// Executor is the executor configuration.
	Executor any `yaml:"executor,omitempty"`
	// Command is the command to run (on shell).
	Command any `yaml:"command,omitempty"`
	// Shell is the shell to run the command. Default is `$SHELL` or `sh`.
	// Can be a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
	Shell any `yaml:"shell,omitempty"`
	// ShellPackages is the list of packages to install.
	// This is used only when the shell is `nix-shell`.
	ShellPackages []string `yaml:"shellPackages,omitempty"`
	// Script is the script to run.
	Script string `yaml:"script,omitempty"`
	// Stdout is the file to write the stdout.
	Stdout string `yaml:"stdout,omitempty"`
	// Stderr is the file to write the stderr.
	Stderr string `yaml:"stderr,omitempty"`
	// Output is the variable name to store the output.
	Output string `yaml:"output,omitempty"`
	// Depends is the list of steps to depend on.
	Depends any `yaml:"depends,omitempty"` // string or []string
	// ContinueOn is the condition to continue on.
	// Can be a string ("skipped", "failed") or an object with detailed config.
	ContinueOn any `yaml:"continueOn,omitempty"`
	// RetryPolicy is the retry policy.
	RetryPolicy *retryPolicyDef `yaml:"retryPolicy,omitempty"`
	// RepeatPolicy is the repeat policy.
	RepeatPolicy *repeatPolicyDef `yaml:"repeatPolicy,omitempty"`
	// MailOnError is the flag to send mail on error.
	MailOnError bool `yaml:"mailOnError,omitempty"`
	// Precondition is the condition to run the step.
	Precondition any `yaml:"precondition,omitempty"`
	// Preconditions is the condition to run the step.
	Preconditions any `yaml:"preconditions,omitempty"`
	// SignalOnStop is the signal when the step is requested to stop.
	// When it is empty, the same signal as the parent process is sent.
	// It can be KILL when the process does not stop over the timeout.
	SignalOnStop *string `yaml:"signalOnStop,omitempty"`
	// Call is the name of a DAG to run as a sub dag-run.
	Call string `yaml:"call,omitempty"`
	// Run is the name of a DAG to run as a sub dag-run.
	// Deprecated: use Call instead.
	Run string `yaml:"run,omitempty"`
	// Params specifies the parameters for the sub dag-run.
	Params any `yaml:"params,omitempty"`
	// Parallel specifies parallel execution configuration.
	// Can be:
	// - Direct array reference: parallel: ${ITEMS}
	// - Static array: parallel: [item1, item2]
	// - Object configuration: parallel: {items: ${ITEMS}, maxConcurrent: 5}
	Parallel any `yaml:"parallel,omitempty"`
	// WorkerSelector specifies required worker labels for execution.
	WorkerSelector map[string]string `yaml:"workerSelector,omitempty"`
	// Env specifies the environment variables for the step.
	Env any `yaml:"env,omitempty"`
	// TimeoutSec specifies the maximum runtime for the step in seconds.
	TimeoutSec int `yaml:"timeoutSec,omitempty"`
}

// repeatPolicyDef defines the repeat policy for a step.
type repeatPolicyDef struct {
	Repeat         any    `yaml:"repeat,omitempty"`         // Flag to indicate if the step should be repeated, can be bool (legacy) or string ("while" or "until")
	IntervalSec    int    `yaml:"intervalSec,omitempty"`    // Interval in seconds to wait before repeating the step
	Limit          int    `yaml:"limit,omitempty"`          // Maximum number of times to repeat the step
	Condition      string `yaml:"condition,omitempty"`      // Condition to check before repeating
	Expected       string `yaml:"expected,omitempty"`       // Expected output to match before repeating
	ExitCode       []int  `yaml:"exitCode,omitempty"`       // List of exit codes to consider for repeating the step
	Backoff        any    `yaml:"backoff,omitempty"`        // Accepts bool or float
	MaxIntervalSec int    `yaml:"maxIntervalSec,omitempty"` // Maximum interval in seconds
}

// retryPolicyDef defines the retry policy for a step.
type retryPolicyDef struct {
	Limit          any   `yaml:"limit,omitempty"`
	IntervalSec    any   `yaml:"intervalSec,omitempty"`
	ExitCode       []int `yaml:"exitCode,omitempty"`
	Backoff        any   `yaml:"backoff,omitempty"` // Accepts bool or float
	MaxIntervalSec int   `yaml:"maxIntervalSec,omitempty"`
}

// smtpConfigDef defines the SMTP configuration.
type smtpConfigDef struct {
	Host     string // SMTP host
	Port     any    // SMTP port (can be string or number)
	Username string // SMTP username
	Password string // SMTP password
}

// mailConfigDef defines the mail configuration.
type mailConfigDef struct {
	From       string // Sender email address
	To         any    // Recipient email address(es) - can be string or []string
	Prefix     string // Prefix for the email subject
	AttachLogs bool   // Flag to attach logs to the email
}

// mailOnDef defines the conditions to send mail.
type mailOnDef struct {
	Failure bool // Send mail on failure
	Success bool // Send mail on success
}

// containerDef defines the container configuration for the DAG.
type containerDef struct {
	// Image is the container image to use.
	Image string `yaml:"image,omitempty"`
	// PullPolicy is the policy to pull the image (e.g., "Always", "IfNotPresent").
	PullPolicy any `yaml:"pullPolicy,omitempty"`
	// Env specifies environment variables for the container.
	Env any `yaml:"env,omitempty"` // Can be a map or struct
	// Volumes specifies the volumes to mount in the container.
	Volumes []string `yaml:"volumes,omitempty"` // Map of volume names to volume definitions
	// User is the user to run the container as.
	User string `yaml:"user,omitempty"` // User to run the container as
	// WorkingDir is the working directory inside the container.
	WorkingDir string `yaml:"workingDir,omitempty"` // Working directory inside the container
	// WorkDir is the working directory inside the container.
	// Deprecated: use WorkingDir instead
	WorkDir string `yaml:"workDir,omitempty"` // Working directory inside the container
	// Platform specifies the platform for the container (e.g., "linux/amd64").
	Platform string `yaml:"platform,omitempty"` // Platform for the container
	// Ports specifies the ports to expose from the container.
	Ports []string `yaml:"ports,omitempty"` // List of ports to expose
	// Network is the network configuration for the container.
	Network string `yaml:"network,omitempty"` // Network configuration for the container
	// KeepContainer is the flag to keep the container after the DAG run.
	KeepContainer bool `yaml:"keepContainer,omitempty"` // Keep the container after the DAG run
	// Startup determines how the DAG-level container starts up.
	Startup string `yaml:"startup,omitempty"`
	// Command used when Startup == "command".
	Command []string `yaml:"command,omitempty"`
	// WaitFor readiness condition: running|healthy
	WaitFor string `yaml:"waitFor,omitempty"`
	// LogPattern regex to wait for in container logs.
	LogPattern string `yaml:"logPattern,omitempty"`
	// RestartPolicy: no|always|unless-stopped
	RestartPolicy string `yaml:"restartPolicy,omitempty"`
}

// runConfigDef defines configuration for controlling user interactions during DAG runs.
type runConfigDef struct {
	DisableParamEdit bool `yaml:"disableParamEdit,omitempty"` // Disable parameter editing when starting DAG
	DisableRunIdEdit bool `yaml:"disableRunIdEdit,omitempty"` // Disable custom run ID specification
}

// sshDef defines the SSH configuration for the DAG.
type sshDef struct {
	// User is the SSH user.
	User string `yaml:"user,omitempty"`
	// Host is the SSH host.
	Host string `yaml:"host,omitempty"`
	// Port is the SSH port (can be string or number).
	Port any `yaml:"port,omitempty"`
	// Key is the path to the SSH private key.
	Key string `yaml:"key,omitempty"`
	// Password is the SSH password.
	Password string `yaml:"password,omitempty"`
	// StrictHostKey enables strict host key checking. Defaults to true if not specified.
	StrictHostKey *bool `yaml:"strictHostKey,omitempty"`
	// KnownHostFile is the path to the known_hosts file. Defaults to ~/.ssh/known_hosts.
	KnownHostFile string `yaml:"knownHostFile,omitempty"`
}

// secretRefDef defines a reference to an external secret.
type secretRefDef struct {
	// Name is the environment variable name (required).
	Name string `yaml:"name"`
	// Provider specifies the secret backend (required).
	Provider string `yaml:"provider"`
	// Key is the provider-specific identifier (required).
	Key string `yaml:"key"`
	// Options contains provider-specific configuration (optional).
	Options map[string]string `yaml:"options,omitempty"`
}
@@ -24,9 +24,15 @@ var (
 	ErrExecutorConfigValueMustBeMap    = errors.New("executor.config value must be a map")
 	ErrExecutorHasInvalidKey           = errors.New("executor has invalid key")
 	ErrExecutorConfigMustBeStringOrMap = errors.New("executor config must be string or map")
+	ErrContainerAndExecutorConflict    = errors.New("cannot use both 'container' field and 'executor' field - the 'container' field already specifies the execution method")
+	ErrContainerAndScriptConflict      = errors.New("cannot use 'script' field with 'container' field - use 'command' field instead")
 	ErrInvalidEnvValue                 = errors.New("env config should be map of strings or array of key=value formatted string")
 	ErrInvalidParamValue               = errors.New("invalid parameter value")
 	ErrStepCommandIsEmpty              = errors.New("step command is empty")
 	ErrStepCommandMustBeArrayOrString  = errors.New("step command must be an array of strings or a string")
+	ErrTimeoutSecMustBeNonNegative     = errors.New("timeoutSec must be >= 0")
 	ErrExecutorDoesNotSupportMultipleCmd = errors.New("executor does not support multiple commands")
+	ErrSubDAGAndExecutorConflict = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'executor' field")
+	ErrSubDAGAndCommandConflict  = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'command' field")
+	ErrSubDAGAndScriptConflict   = errors.New("cannot use sub-DAG field ('call', 'run', or 'parallel') with 'script' field")
 )
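A minimal sketch of a spec that should now trip one of the new conflict errors. The YAML is invented; `spec.LoadYAML` is the loader used elsewhere in this diff:

```go
data := []byte(`
steps:
  - name: bad
    call: child-dag
    command: echo hi
`)
_, err := spec.LoadYAML(context.Background(), data)
// Expected: err mentions "cannot use sub-DAG field ('call', 'run', or 'parallel') with 'command' field"
```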
@@ -14,6 +14,7 @@ import (
	"dario.cat/mergo"
	"github.com/dagu-org/dagu/internal/common/fileutil"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/go-viper/mapstructure/v2"

	"github.com/goccy/go-yaml"
@@ -212,7 +213,7 @@ func LoadYAMLWithOpts(ctx context.Context, data []byte, opts BuildOpts) (*core.D
		return nil, core.ErrorList{err}
	}

	return build(BuildContext{ctx: ctx, opts: opts}, def)
	return def.build(BuildContext{ctx: ctx, opts: opts})
}

// LoadBaseConfig loads the global configuration from the given file.
@@ -229,7 +230,7 @@ func LoadBaseConfig(ctx BuildContext, file string) (*core.DAG, error) {
		return nil, err
	}

	// Decode the raw data into a config definition.
	// Decode the raw data into a manifest.
	def, err := decode(raw)
	if err != nil {
		return nil, core.ErrorList{err}
@@ -238,12 +239,8 @@ func LoadBaseConfig(ctx BuildContext, file string) (*core.DAG, error) {
	ctx = ctx.WithOpts(BuildOpts{
		Flags: ctx.opts.Flags,
	}).WithFile(file)
	dag, err := build(ctx, def)

	if err != nil {
		return nil, core.ErrorList{err}
	}
	return dag, nil
	return def.build(ctx)
}

// loadDAG loads the core.DAG from the given file.
@@ -255,8 +252,8 @@ func loadDAG(ctx BuildContext, nameOrPath string) (*core.DAG, error) {

	ctx = ctx.WithFile(filePath)

	// Load base config definition if specified
	var baseDef *definition
	// Load base manifest if specified
	var baseDef *dag
	if !ctx.opts.Has(BuildFlagOnlyMetadata) && ctx.opts.Base != "" && fileutil.FileExists(ctx.opts.Base) {
		raw, err := readYAMLFile(ctx.opts.Base)
		if err != nil {
@@ -308,7 +305,7 @@ func loadDAG(ctx BuildContext, nameOrPath string) (*core.DAG, error) {
}

// loadDAGsFromFile loads all DAGs from a multi-document YAML file
func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([]*core.DAG, error) {
func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *dag) ([]*core.DAG, error) {
	// Open the file
	f, err := os.Open(filePath) //nolint:gosec
	if err != nil {
@@ -348,13 +345,13 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
		// Update the context with the current document index
		ctx.index = docIndex

		// Decode the document into definition
		// Decode the document into manifest
		spec, err := decode(doc)
		if err != nil {
			return nil, fmt.Errorf("failed to decode document %d: %w", docIndex, err)
		}

		// Build a fresh base core.DAG from base definition if provided
		// Build a fresh base core.DAG from base manifest if provided
		var dest *core.DAG
		if baseDef != nil {
			// Build a new base core.DAG for this document
@@ -363,7 +360,7 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
			buildCtx.opts.Parameters = ""
			buildCtx.opts.ParametersList = nil
			// Build the base core.DAG
			baseDAG, err := build(buildCtx, baseDef)
			baseDAG, err := baseDef.build(buildCtx)
			if err != nil {
				return nil, fmt.Errorf("failed to build base core.DAG for document %d: %w", docIndex, err)
			}
@@ -380,7 +377,7 @@ func loadDAGsFromFile(ctx BuildContext, filePath string, baseDef *definition) ([
		}

		// Build the core.DAG from the current document
		dag, err := build(ctx, spec)
		dag, err := spec.build(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to build core.DAG in document %d: %w", docIndex, err)
		}
@@ -540,19 +537,75 @@ func unmarshalData(data []byte) (map[string]any, error) {
	return cm, err
}

// decode decodes the configuration map into a configDefinition.
func decode(cm map[string]any) (*definition, error) {
	c := new(definition)
// decode decodes the configuration map into a manifest.
func decode(cm map[string]any) (*dag, error) {
	c := new(dag)
	md, _ := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		ErrorUnused: true,
		Result:      c,
		TagName:     "",
		DecodeHook:  TypedUnionDecodeHook(),
	})
	err := md.Decode(cm)

	return c, err
}

// TypedUnionDecodeHook returns a decode hook that handles our typed union types.
// It converts raw map[string]any values to the appropriate typed values.
func TypedUnionDecodeHook() mapstructure.DecodeHookFunc {
	return func(_ reflect.Type, to reflect.Type, data any) (any, error) {
		// Handle types.ShellValue
		if to == reflect.TypeOf(types.ShellValue{}) {
			return decodeViaYAML[types.ShellValue](data)
		}
		// Handle types.StringOrArray
		if to == reflect.TypeOf(types.StringOrArray{}) {
			return decodeViaYAML[types.StringOrArray](data)
		}
		// Handle types.ScheduleValue
		if to == reflect.TypeOf(types.ScheduleValue{}) {
			return decodeViaYAML[types.ScheduleValue](data)
		}
		// Handle types.EnvValue
		if to == reflect.TypeOf(types.EnvValue{}) {
			return decodeViaYAML[types.EnvValue](data)
		}
		// Handle types.ContinueOnValue
		if to == reflect.TypeOf(types.ContinueOnValue{}) {
			return decodeViaYAML[types.ContinueOnValue](data)
		}
		// Handle types.PortValue
		if to == reflect.TypeOf(types.PortValue{}) {
			return decodeViaYAML[types.PortValue](data)
		}
		// Handle types.LogOutputValue
		if to == reflect.TypeOf(types.LogOutputValue{}) {
			return decodeViaYAML[types.LogOutputValue](data)
		}
		return data, nil
	}
}

// decodeViaYAML converts data to YAML and unmarshals it to the target type.
// This allows the custom UnmarshalYAML methods to be used.
func decodeViaYAML[T any](data any) (T, error) {
	var result T
	if data == nil {
		return result, nil
	}
	// Convert the data to YAML bytes
	yamlBytes, err := yaml.Marshal(data)
	if err != nil {
		return result, fmt.Errorf("failed to marshal to YAML: %w", err)
	}
	// Unmarshal using the custom UnmarshalYAML method
	if err := yaml.Unmarshal(yamlBytes, &result); err != nil {
		return result, fmt.Errorf("failed to unmarshal from YAML: %w", err)
	}
	return result, nil
}
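The YAML round trip above is the key trick: mapstructure hands the hook a raw value (scalar or map[string]any), which is re-marshaled to YAML and unmarshaled into the union type so that type's custom UnmarshalYAML runs. A standalone sketch of the same pattern follows; IntOrString is a made-up stand-in, not the real spec/types API.

// Sketch of the decode-via-YAML round trip, under the assumption stated above.
package main

import (
	"fmt"

	"github.com/goccy/go-yaml"
)

// IntOrString is a hypothetical union type that accepts an int or a string.
type IntOrString struct {
	Value string
}

// UnmarshalYAML normalizes either form to a string.
func (v *IntOrString) UnmarshalYAML(b []byte) error {
	var raw any
	if err := yaml.Unmarshal(b, &raw); err != nil {
		return err
	}
	v.Value = fmt.Sprintf("%v", raw)
	return nil
}

// roundTrip mirrors decodeViaYAML: marshal the raw value back to YAML, then
// unmarshal into the target so its custom UnmarshalYAML is invoked.
func roundTrip[T any](data any) (T, error) {
	var out T
	b, err := yaml.Marshal(data)
	if err != nil {
		return out, err
	}
	err = yaml.Unmarshal(b, &out)
	return out, err
}

func main() {
	got, _ := roundTrip[IntOrString](8080) // raw value as mapstructure saw it
	fmt.Println(got.Value)                 // "8080"
}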

// merge merges the source core.DAG into the destination DAG.
func merge(dst, src *core.DAG) error {
	return mergo.Merge(dst, src, mergo.WithOverride,

@@ -15,7 +15,7 @@ import (
func TestLoad(t *testing.T) {
	t.Parallel()

	t.Run(("WithName"), func(t *testing.T) {
	t.Run("WithName", func(t *testing.T) {
		t.Parallel()

		testDAG := createTempYAMLFile(t, `steps:
@@ -26,43 +26,72 @@ func TestLoad(t *testing.T) {
		require.NoError(t, err)
		require.Equal(t, "testDAG", dag.Name)
	})
	t.Run("InvalidPath", func(t *testing.T) {
		t.Parallel()

		// Use a non-existing file path
		testDAG := "/tmp/non_existing_file_" + t.Name() + ".yaml"
		_, err := spec.Load(context.Background(), testDAG)
		require.Error(t, err)
	})
	t.Run("UnknownField", func(t *testing.T) {
		t.Parallel()
	// Error cases
	errorTests := []struct {
		name        string
		content     string
		useFile     bool
		errContains string
	}{
		{
			name:    "InvalidPath",
			useFile: false,
		},
		{
			name:        "UnknownField",
			content:     "invalidKey: test\n",
			useFile:     true,
			errContains: "has invalid keys: invalidKey",
		},
		{
			name:        "InvalidYAML",
			content:     "invalidyaml",
			useFile:     true,
			errContains: "invalidyaml",
		},
	}

		testDAG := createTempYAMLFile(t, `invalidKey: test
`)
		_, err := spec.Load(context.Background(), testDAG)
		require.Error(t, err)
		require.Contains(t, err.Error(), "has invalid keys: invalidKey")
	})
	t.Run("InvalidYAML", func(t *testing.T) {
		t.Parallel()
	for _, tt := range errorTests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var testDAG string
			if tt.useFile {
				testDAG = createTempYAMLFile(t, tt.content)
			} else {
				testDAG = "/tmp/non_existing_file_" + t.Name() + ".yaml"
			}
			_, err := spec.Load(context.Background(), testDAG)
			require.Error(t, err)
			if tt.errContains != "" {
				require.Contains(t, err.Error(), tt.errContains)
			}
		})
	}

		testDAG := createTempYAMLFile(t, `invalidyaml`)
		_, err := spec.Load(context.Background(), testDAG)
		require.Error(t, err)
		require.Contains(t, err.Error(), "invalidyaml")
	})
	t.Run("MetadataOnly", func(t *testing.T) {
		t.Parallel()

		testDAG := createTempYAMLFile(t, `steps:
		testDAG := createTempYAMLFile(t, `
logDir: /var/log/dagu
histRetentionDays: 90
maxCleanUpTimeSec: 60
mailOn:
  failure: true
steps:
  - name: "1"
    command: "true"
`)
		dag, err := spec.Load(context.Background(), testDAG, spec.OnlyMetadata())
		require.NoError(t, err)
		// Steps should not be loaded in metadata-only mode
		require.Empty(t, dag.Steps)
		// Check if the metadata is loaded correctly
		require.Len(t, dag.Steps, 0)
		// Non-metadata fields from YAML should NOT be populated in metadata-only mode
		assert.Empty(t, dag.LogDir, "LogDir should be empty in metadata-only mode")
		assert.Nil(t, dag.MailOn, "MailOn should be nil in metadata-only mode")
		// Defaults are still applied by InitializeDefaults (not from YAML)
		assert.Equal(t, 30, dag.HistRetentionDays, "HistRetentionDays should be default (30), not YAML value (90)")
		assert.Equal(t, 5*time.Second, dag.MaxCleanUpTime, "MaxCleanUpTime should be default (5s), not YAML value (60s)")
	})
	t.Run("DefaultConfig", func(t *testing.T) {
		t.Parallel()
@@ -84,7 +113,8 @@ func TestLoad(t *testing.T) {
		// Step level
		require.Len(t, dag.Steps, 1)
		assert.Equal(t, "1", dag.Steps[0].Name, "1")
		assert.Equal(t, "true", dag.Steps[0].Command, "true")
		require.Len(t, dag.Steps[0].Commands, 1)
		assert.Equal(t, "true", dag.Steps[0].Commands[0].Command)
	})
	t.Run("OverrideConfig", func(t *testing.T) {
		t.Parallel()
@@ -190,26 +220,45 @@ steps:

func TestLoadYAML(t *testing.T) {
	t.Parallel()
	const testDAG = `steps:

	tests := []struct {
		name        string
		input       string
		wantErr     bool
		wantName    string
		wantCommand string
	}{
		{
			name:  "ValidYAMLData",
			input: `steps:
  - name: "1"
    command: "true"
`
	t.Run("ValidYAMLData", func(t *testing.T) {
		t.Parallel()
`,
			wantName:    "1",
			wantCommand: "true",
		},
		{
			name:    "InvalidYAMLData",
			input:   "invalid",
			wantErr: true,
		},
	}

		ret, err := spec.LoadYAMLWithOpts(context.Background(), []byte(testDAG), spec.BuildOpts{})
		require.NoError(t, err)

		step := ret.Steps[0]
		require.Equal(t, "1", step.Name)
		require.Equal(t, "true", step.Command)
	})
	t.Run("InvalidYAMLData", func(t *testing.T) {
		t.Parallel()

		_, err := spec.LoadYAMLWithOpts(context.Background(), []byte(`invalid`), spec.BuildOpts{})
		require.Error(t, err)
	})
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ret, err := spec.LoadYAMLWithOpts(context.Background(), []byte(tt.input), spec.BuildOpts{})
			if tt.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			require.Len(t, ret.Steps, 1)
			assert.Equal(t, tt.wantName, ret.Steps[0].Name)
			require.Len(t, ret.Steps[0].Commands, 1)
			assert.Equal(t, tt.wantCommand, ret.Steps[0].Commands[0].Command)
		})
	}
}

func TestLoadYAMLWithNameOption(t *testing.T) {
@@ -228,7 +277,8 @@ steps:

	step := ret.Steps[0]
	require.Equal(t, "1", step.Name)
	require.Equal(t, "true", step.Command)
	require.Len(t, step.Commands, 1)
	require.Equal(t, "true", step.Commands[0].Command)
}

// createTempYAMLFile creates a temporary YAML file with the given content
@@ -295,7 +345,8 @@ steps:
		assert.Equal(t, "transform-data", transformDAG.Name)
		assert.Len(t, transformDAG.Steps, 1)
		assert.Equal(t, "transform", transformDAG.Steps[0].Name)
		assert.Equal(t, "transform.py", transformDAG.Steps[0].Command)
		require.Len(t, transformDAG.Steps[0].Commands, 1)
		assert.Equal(t, "transform.py", transformDAG.Steps[0].Commands[0].Command)

		// Check archive-results sub DAG
		_, exists = dag.LocalDAGs["archive-results"]
@@ -304,7 +355,8 @@ steps:
		assert.Equal(t, "archive-results", archiveDAG.Name)
		assert.Len(t, archiveDAG.Steps, 1)
		assert.Equal(t, "archive", archiveDAG.Steps[0].Name)
		assert.Equal(t, "archive.sh", archiveDAG.Steps[0].Command)
		require.Len(t, archiveDAG.Steps[0].Commands, 1)
		assert.Equal(t, "archive.sh", archiveDAG.Steps[0].Commands[0].Command)
	})

	t.Run("MultiDAGWithBaseConfig", func(t *testing.T) {
@@ -385,11 +437,15 @@ steps:
		assert.Nil(t, dag.LocalDAGs) // No sub DAGs for single DAG file
	})

	t.Run("DuplicateSubDAGNames", func(t *testing.T) {
		t.Parallel()

		// Multi-DAG file with duplicate names
		multiDAGContent := `steps:
	// MultiDAG error cases
	multiDAGErrorTests := []struct {
		name        string
		content     string
		errContains string
	}{
		{
			name: "DuplicateSubDAGNames",
			content: `steps:
  - name: step1
    command: echo "main"

@@ -404,20 +460,12 @@ name: duplicate-name
steps:
  - name: step1
    command: echo "second"
`
		tmpFile := createTempYAMLFile(t, multiDAGContent)

		// Load should fail due to duplicate names
		_, err := spec.Load(context.Background(), tmpFile)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "duplicate DAG name")
	})

	t.Run("SubDAGWithoutName", func(t *testing.T) {
		t.Parallel()

		// Multi-DAG file where sub DAG has no name
		multiDAGContent := `steps:
`,
			errContains: "duplicate DAG name",
		},
		{
			name: "SubDAGWithoutName",
			content: `steps:
  - name: step1
    command: echo "main"

@@ -425,14 +473,20 @@ steps:
steps:
  - name: step1
    command: echo "unnamed"
`
		tmpFile := createTempYAMLFile(t, multiDAGContent)
`,
			errContains: "must have a name",
		},
	}

		// Load should fail because sub DAG has no name
		_, err := spec.Load(context.Background(), tmpFile)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "must have a name")
	})
	for _, tt := range multiDAGErrorTests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			tmpFile := createTempYAMLFile(t, tt.content)
			_, err := spec.Load(context.Background(), tmpFile)
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.errContains)
		})
	}

	t.Run("EmptyDocumentSeparator", func(t *testing.T) {
		t.Parallel()
@@ -615,3 +669,147 @@ steps:
		assert.Equal(t, explicitDir, dag.WorkingDir)
	})
}

func TestLoadWithLoaderOptions(t *testing.T) {
	t.Parallel()

	t.Run("WithDAGsDir", func(t *testing.T) {
		t.Parallel()

		// Create a DAGs directory
		dagsDir := t.TempDir()

		// Create a sub-DAG file
		subDAGPath := filepath.Join(dagsDir, "subdag.yaml")
		require.NoError(t, os.WriteFile(subDAGPath, []byte(`
steps:
  - name: sub-step
    command: echo sub
`), 0644))

		// Create main DAG that calls the sub-DAG
		mainDAG := createTempYAMLFile(t, `
steps:
  - name: main-step
    command: echo main
`)
		// Load with WithDAGsDir
		dag, err := spec.Load(context.Background(), mainDAG, spec.WithDAGsDir(dagsDir))
		require.NoError(t, err)
		require.NotNil(t, dag)
	})

	t.Run("WithAllowBuildErrors", func(t *testing.T) {
		t.Parallel()

		testDAG := createTempYAMLFile(t, `
steps:
  - name: test
    command: echo test
    depends:
      - nonexistent-step
`)
		// Without AllowBuildErrors, this would fail
		_, err := spec.Load(context.Background(), testDAG)
		require.Error(t, err)

		// With AllowBuildErrors, it should succeed but capture errors
		dag, err := spec.Load(context.Background(), testDAG, spec.WithAllowBuildErrors())
		require.NoError(t, err)
		require.NotNil(t, dag)
		assert.NotEmpty(t, dag.BuildErrors)
	})

	t.Run("SkipSchemaValidation", func(t *testing.T) {
		t.Parallel()

		testDAG := createTempYAMLFile(t, `
params:
  schema: "nonexistent-schema.json"
  values:
    foo: bar
steps:
  - name: test
    command: echo test
`)
		// Without SkipSchemaValidation, this would fail due to missing schema
		_, err := spec.Load(context.Background(), testDAG)
		require.Error(t, err)

		// With SkipSchemaValidation, it should succeed
		dag, err := spec.Load(context.Background(), testDAG, spec.SkipSchemaValidation())
		require.NoError(t, err)
		require.NotNil(t, dag)
	})

	t.Run("WithSkipBaseHandlers", func(t *testing.T) {
		t.Parallel()

		// Create base config with handlers
		baseDir := t.TempDir()
		baseConfig := filepath.Join(baseDir, "base.yaml")
		require.NoError(t, os.WriteFile(baseConfig, []byte(`
handlerOn:
  success:
    command: echo base-success
`), 0644))

		testDAG := createTempYAMLFile(t, `
steps:
  - name: test
    command: echo test
`)
		// Load with base config but skip base handlers
		dag, err := spec.Load(context.Background(), testDAG,
			spec.WithBaseConfig(baseConfig),
			spec.WithSkipBaseHandlers())
		require.NoError(t, err)
		require.NotNil(t, dag)
		// The base success handler should not be present
		assert.Nil(t, dag.HandlerOn.Success)
	})

	t.Run("WithParamsAsList", func(t *testing.T) {
		t.Parallel()

		testDAG := createTempYAMLFile(t, `
params: KEY1 KEY2
steps:
  - name: test
    command: echo $KEY1 $KEY2
`)
		// Load with params as list
		dag, err := spec.Load(context.Background(), testDAG,
			spec.WithParams([]string{"KEY1=value1", "KEY2=value2"}))
		require.NoError(t, err)
		require.NotNil(t, dag)

		// Check that params were applied
		found := 0
		for _, p := range dag.Params {
			if p == "KEY1=value1" || p == "KEY2=value2" {
				found++
			}
		}
		assert.Equal(t, 2, found, "Both params should be applied")
	})
}

// TestLoadWithoutEval tests the WithoutEval loader option
// This test cannot be parallel because it uses t.Setenv
func TestLoadWithoutEval(t *testing.T) {
	t.Setenv("TEST_VAR", "should-not-expand")

	testDAG := createTempYAMLFile(t, `
env:
  - MY_VAR: "${TEST_VAR}"
steps:
  - name: test
    command: echo test
`)
	dag, err := spec.Load(context.Background(), testDAG, spec.WithoutEval())
	require.NoError(t, err)

	// With NoEval, environment variables should not be expanded
	assert.Contains(t, dag.Env, "MY_VAR=${TEST_VAR}")
}

@@ -3,88 +3,17 @@ package spec
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"

	"github.com/dagu-org/dagu/internal/common/cmdutil"
	"github.com/dagu-org/dagu/internal/common/fileutil"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/google/jsonschema-go/jsonschema"
)

// buildParams builds the parameters for the DAG.
func buildParams(ctx BuildContext, spec *definition, dag *core.DAG) error {
	var (
		paramPairs []paramPair
		envs       []string
	)

	if err := parseParams(ctx, spec.Params, &paramPairs, &envs); err != nil {
		return err
	}

	// Create default parameters string in the form of "key=value key=value ..."
	var paramsToJoin []string
	for _, paramPair := range paramPairs {
		paramsToJoin = append(paramsToJoin, paramPair.Escaped())
	}
	dag.DefaultParams = strings.Join(paramsToJoin, " ")

	if ctx.opts.Parameters != "" {
		// Parse the parameters from the command line and override the default parameters
		var (
			overridePairs []paramPair
			overrideEnvs  []string
		)
		if err := parseParams(ctx, ctx.opts.Parameters, &overridePairs, &overrideEnvs); err != nil {
			return err
		}
		// Override the default parameters with the command line parameters
		overrideParams(&paramPairs, overridePairs)
	}

	if len(ctx.opts.ParametersList) > 0 {
		var (
			overridePairs []paramPair
			overrideEnvs  []string
		)
		if err := parseParams(ctx, ctx.opts.ParametersList, &overridePairs, &overrideEnvs); err != nil {
			return err
		}
		// Override the default parameters with the command line parameters
		overrideParams(&paramPairs, overridePairs)
	}

	// Validate the parameters against a resolved schema (if declared)
	if !ctx.opts.Has(BuildFlagSkipSchemaValidation) {
		if resolvedSchema, err := resolveSchemaFromParams(spec.Params, spec.WorkingDir, dag.Location); err != nil {
			return fmt.Errorf("failed to get JSON schema: %w", err)
		} else if resolvedSchema != nil {
			updatedPairs, err := validateParams(paramPairs, resolvedSchema)
			if err != nil {
				return err
			}
			paramPairs = updatedPairs
		}
	}

	for _, paramPair := range paramPairs {
		dag.Params = append(dag.Params, paramPair.String())
	}

	dag.Env = append(dag.Env, envs...)

	return nil
}
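A minimal sketch of the override precedence buildParams applies: spec defaults are overridden first by opts.Parameters, then by opts.ParametersList. paramPair here mirrors the unexported type in this package, and the sketch handles only named parameters (the real overrideParams also matches positional parameters by index).

// Sketch only: named-parameter override precedence, under the assumptions above.
package main

import "fmt"

type paramPair struct{ Name, Value string }

func override(params *[]paramPair, overrides []paramPair) {
	idx := map[string]int{}
	for i, p := range *params {
		idx[p.Name] = i
	}
	for _, o := range overrides {
		if i, ok := idx[o.Name]; ok {
			(*params)[i] = o // replace an existing named param
		} else {
			*params = append(*params, o) // append a new one
		}
	}
}

func main() {
	params := []paramPair{{"foo", "default"}, {"bar", "keep"}}
	override(&params, []paramPair{{"foo", "cli"}})      // from opts.Parameters
	override(&params, []paramPair{{"baz", "list-val"}}) // from opts.ParametersList
	fmt.Println(params) // [{foo cli} {bar keep} {baz list-val}]
}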

func validateParams(paramPairs []paramPair, schema *jsonschema.Resolved) ([]paramPair, error) {
	// Convert paramPairs to a map for validation
	paramMap := make(map[string]any)
@@ -127,134 +56,6 @@ func validateParams(paramPairs []paramPair, schema *jsonschema.Resolved) ([]para
	return updatedPairs, nil
}

// Schema Ref can be a local file (relative or absolute paths), or a remote URL
func getSchemaFromRef(workingDir string, dagLocation string, schemaRef string) (*jsonschema.Resolved, error) {
	var schemaData []byte
	var err error

	// Check if it's a URL or file path
	if strings.HasPrefix(schemaRef, "http://") || strings.HasPrefix(schemaRef, "https://") {
		schemaData, err = loadSchemaFromURL(schemaRef)
	} else {
		schemaData, err = loadSchemaFromFile(workingDir, dagLocation, schemaRef)
	}

	if err != nil {
		return nil, fmt.Errorf("failed to load schema from %s: %w", schemaRef, err)
	}

	var schema jsonschema.Schema
	if err := json.Unmarshal(schemaData, &schema); err != nil {
		return nil, fmt.Errorf("failed to parse schema JSON: %w", err)
	}

	resolveOptions := &jsonschema.ResolveOptions{
		ValidateDefaults: true,
	}

	resolvedSchema, err := schema.Resolve(resolveOptions)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve schema: %w", err)
	}

	return resolvedSchema, nil
}

// loadSchemaFromURL loads a JSON schema from a URL.
func loadSchemaFromURL(schemaURL string) (data []byte, err error) {
	// Validate URL to prevent potential security issues (and satisfy linter :P)
	parsedURL, err := url.Parse(schemaURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return nil, fmt.Errorf("unsupported URL scheme: %s", parsedURL.Scheme)
	}

	req, err := http.NewRequest("GET", schemaURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
	}

	data, err = io.ReadAll(resp.Body)
	return data, err
}

// loadSchemaFromFile loads a JSON schema from a file path.
func loadSchemaFromFile(workingDir string, dagLocation string, filePath string) ([]byte, error) {
	// Try to resolve the schema file path in the following order:
	// 1) Current working directory (default ResolvePath behavior)
	// 2) DAG's workingDir value
	// 3) Directory of the DAG file (where it was loaded from)

	var tried []string

	// Attempts a candidate by joining base and filePath (if base provided),
	// resolving env/tilde + absolute path, checking existence, and reading.
	tryCandidate := func(label, base string) ([]byte, string, error) {
		var candidate string
		if strings.TrimSpace(base) == "" {
			candidate = filePath
		} else {
			candidate = filepath.Join(base, filePath)
		}
		resolved, err := fileutil.ResolvePath(candidate)
		if err != nil {
			tried = append(tried, fmt.Sprintf("%s: resolve error: %v", label, err))
			return nil, "", err
		}
		if !fileutil.FileExists(resolved) {
			tried = append(tried, fmt.Sprintf("%s: %s", label, resolved))
			return nil, resolved, os.ErrNotExist
		}
		data, err := os.ReadFile(resolved) // #nosec G304 - validated path
		if err != nil {
			tried = append(tried, fmt.Sprintf("%s: %s (read error: %v)", label, resolved, err))
			return nil, resolved, err
		}
		return data, resolved, nil
	}

	// 1) As provided (CWD/env/tilde expansion handled by ResolvePath)
	if data, _, err := tryCandidate("cwd", ""); err == nil {
		return data, nil
	}

	// 2) From DAG's workingDir value if present
	if wd := strings.TrimSpace(workingDir); wd != "" {
		if data, _, err := tryCandidate(fmt.Sprintf("workingDir(%s)", wd), wd); err == nil {
			return data, nil
		}
	}

	// 3) From the directory of the DAG file used to build
	if dagLocation != "" {
		base := filepath.Dir(dagLocation)
		if data, _, err := tryCandidate(fmt.Sprintf("dagDir(%s)", base), base); err == nil {
			return data, nil
		}
	}

	if len(tried) == 0 {
		return nil, fmt.Errorf("failed to resolve schema file path: %s (no candidates)", filePath)
	}
	return nil, fmt.Errorf("schema file not found for %q; tried %s", filePath, strings.Join(tried, ", "))
}

func overrideParams(paramPairs *[]paramPair, override []paramPair) {
	// Override the default parameters with the command line parameters (and satisfy linter :P)
	pairsIndex := make(map[string]int)
@@ -514,57 +315,3 @@ func (p paramPair) Escaped() string {
	}
	return fmt.Sprintf("%q", p.Value)
}

// extractSchemaReference extracts the schema reference from a params map.
// Returns the schema reference as a string if present and valid, empty string otherwise.
func extractSchemaReference(params any) string {
	paramsMap, ok := params.(map[string]any)
	if !ok {
		return ""
	}

	schemaRef, hasSchema := paramsMap["schema"]
	if !hasSchema {
		return ""
	}

	schemaRefStr, ok := schemaRef.(string)
	if !ok {
		return ""
	}

	return schemaRefStr
}

// resolveSchemaFromParams extracts a schema reference from params and resolves it.
// Returns (nil, nil) if no schema is declared.
func resolveSchemaFromParams(params any, workingDir, dagLocation string) (*jsonschema.Resolved, error) {
	schemaRef := extractSchemaReference(params)
	if schemaRef == "" {
		return nil, nil
	}
	return getSchemaFromRef(workingDir, dagLocation, schemaRef)
}

// buildStepParams parses the params field in the step definition.
// Params are converted to map[string]string and stored in step.Params
func buildStepParams(ctx StepBuildContext, def stepDef, step *core.Step) error {
	if def.Params == nil {
		return nil
	}

	// Parse params using existing parseParamValue function
	paramPairs, err := parseParamValue(ctx.BuildContext, def.Params)
	if err != nil {
		return core.NewValidationError("params", def.Params, err)
	}

	// Convert to map[string]string
	paramsData := make(map[string]string)
	for _, pair := range paramPairs {
		paramsData[pair.Name] = pair.Value
	}

	step.Params = core.NewSimpleParams(paramsData)
	return nil
}
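A quick sketch of the flattening buildStepParams performs: step-level params collapse to map[string]string, so non-string scalars become strings (as the IntegerValue/BooleanValue cases in the tests below also show). paramPair mirrors the unexported type above; the values are illustrative.

// Sketch only: step params flattened to map[string]string.
package main

import "fmt"

type paramPair struct{ Name, Value string }

func main() {
	// e.g. from a hypothetical step block:
	//   params:
	//     - REGION: us-east-1
	//     - RETRIES: 3        # becomes the string "3"
	pairs := []paramPair{{"REGION", "us-east-1"}, {"RETRIES", "3"}}
	paramsData := make(map[string]string)
	for _, p := range pairs {
		paramsData[p.Name] = p.Value
	}
	fmt.Println(paramsData) // map[REGION:us-east-1 RETRIES:3]
}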
internal/core/spec/params_test.go (new file, 662 lines)
@@ -0,0 +1,662 @@
package spec

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestParamPairString(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		pair     paramPair
		expected string
	}{
		{
			name:     "NamedParam",
			pair:     paramPair{Name: "foo", Value: "bar"},
			expected: "foo=bar",
		},
		{
			name:     "PositionalParam",
			pair:     paramPair{Name: "", Value: "value"},
			expected: "value",
		},
		{
			name:     "EmptyValue",
			pair:     paramPair{Name: "key", Value: ""},
			expected: "key=",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.expected, tt.pair.String())
		})
	}
}

func TestParamPairEscaped(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		pair     paramPair
		expected string
	}{
		{
			name:     "NamedParam",
			pair:     paramPair{Name: "foo", Value: "bar"},
			expected: `foo="bar"`,
		},
		{
			name:     "PositionalParam",
			pair:     paramPair{Name: "", Value: "value"},
			expected: `"value"`,
		},
		{
			name:     "ValueWithSpaces",
			pair:     paramPair{Name: "msg", Value: "hello world"},
			expected: `msg="hello world"`,
		},
		{
			name:     "ValueWithQuotes",
			pair:     paramPair{Name: "json", Value: `{"key":"value"}`},
			expected: `json="{\"key\":\"value\"}"`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.expected, tt.pair.Escaped())
		})
	}
}

func TestParseStringParams(t *testing.T) {
	t.Parallel()

	ctx := BuildContext{
		ctx:  context.Background(),
		opts: BuildOpts{Flags: BuildFlagNoEval},
	}

	tests := []struct {
		name     string
		input    string
		expected []paramPair
	}{
		{
			name:     "SinglePositionalParam",
			input:    "value",
			expected: []paramPair{{Name: "", Value: "value"}},
		},
		{
			name:     "SingleNamedParam",
			input:    "key=value",
			expected: []paramPair{{Name: "key", Value: "value"}},
		},
		{
			name:  "MultipleNamedParams",
			input: "foo=bar baz=qux",
			expected: []paramPair{
				{Name: "foo", Value: "bar"},
				{Name: "baz", Value: "qux"},
			},
		},
		{
			name:  "MixedParams",
			input: "positional key=value",
			expected: []paramPair{
				{Name: "", Value: "positional"},
				{Name: "key", Value: "value"},
			},
		},
		{
			name:     "QuotedValue",
			input:    `msg="hello world"`,
			expected: []paramPair{{Name: "msg", Value: "hello world"}},
		},
		{
			name:     "QuotedValueWithEscapedQuotes",
			input:    `msg="say \"hello\""`,
			expected: []paramPair{{Name: "msg", Value: `say "hello\`}},
		},
		{
			name:     "EmptyString",
			input:    "",
			expected: nil,
		},
		{
			name:  "MultiplePositionalParams",
			input: "one two three",
			expected: []paramPair{
				{Name: "", Value: "one"},
				{Name: "", Value: "two"},
				{Name: "", Value: "three"},
			},
		},
		{
			name:     "BacktickValue",
			input:    "cmd=`echo hello`",
			expected: []paramPair{{Name: "cmd", Value: "`echo hello`"}},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result, err := parseStringParams(ctx, tt.input)
			require.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestParseStringParamsWithEval(t *testing.T) {
	t.Run("BacktickCommandSubstitution", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		result, err := parseStringParams(ctx, "val=`echo hello`")
		require.NoError(t, err)
		require.Len(t, result, 1)
		assert.Equal(t, "val", result[0].Name)
		assert.Equal(t, "hello", result[0].Value)
	})
}

func TestParseListParams(t *testing.T) {
	t.Parallel()

	ctx := BuildContext{
		ctx:  context.Background(),
		opts: BuildOpts{Flags: BuildFlagNoEval},
	}

	tests := []struct {
		name     string
		input    []string
		expected []paramPair
	}{
		{
			name:     "EmptyList",
			input:    []string{},
			expected: nil,
		},
		{
			name:     "SingleItem",
			input:    []string{"foo=bar"},
			expected: []paramPair{{Name: "foo", Value: "bar"}},
		},
		{
			name:  "MultipleItems",
			input: []string{"foo=bar", "baz=qux"},
			expected: []paramPair{
				{Name: "foo", Value: "bar"},
				{Name: "baz", Value: "qux"},
			},
		},
		{
			name:  "ItemsWithMultipleParams",
			input: []string{"a=1 b=2", "c=3"},
			expected: []paramPair{
				{Name: "a", Value: "1"},
				{Name: "b", Value: "2"},
				{Name: "c", Value: "3"},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result, err := parseListParams(ctx, tt.input)
			require.NoError(t, err)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestParseMapParams(t *testing.T) {
	t.Parallel()

	ctx := BuildContext{
		ctx:  context.Background(),
		opts: BuildOpts{Flags: BuildFlagNoEval},
	}

	t.Run("EmptySlice", func(t *testing.T) {
		t.Parallel()
		result, err := parseMapParams(ctx, []any{})
		require.NoError(t, err)
		assert.Nil(t, result)
	})

	t.Run("SingleMap", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"foo": "bar"},
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
	})

	t.Run("MapWithMultipleKeys", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"a": "1", "b": "2"},
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		// Keys are sorted alphabetically
		assert.Equal(t, []paramPair{
			{Name: "a", Value: "1"},
			{Name: "b", Value: "2"},
		}, result)
	})

	t.Run("MultipleMaps", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"foo": "bar"},
			map[string]any{"baz": "qux"},
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, result)
	})

	t.Run("MixedMapAndString", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"foo": "bar"},
			"baz=qux",
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, result)
	})

	t.Run("IntegerValue", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"count": 42},
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, []paramPair{{Name: "count", Value: "42"}}, result)
	})

	t.Run("BooleanValue", func(t *testing.T) {
		t.Parallel()
		input := []any{
			map[string]any{"debug": true},
		}
		result, err := parseMapParams(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, []paramPair{{Name: "debug", Value: "true"}}, result)
	})

	t.Run("InvalidType", func(t *testing.T) {
		t.Parallel()
		input := []any{123}
		_, err := parseMapParams(ctx, input)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid parameter value")
	})
}

func TestParseParamValue(t *testing.T) {
	t.Parallel()

	ctx := BuildContext{
		ctx:  context.Background(),
		opts: BuildOpts{Flags: BuildFlagNoEval},
	}

	t.Run("Nil", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, nil)
		require.NoError(t, err)
		assert.Nil(t, result)
	})

	t.Run("String", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, "foo=bar baz=qux")
		require.NoError(t, err)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, result)
	})

	t.Run("StringSlice", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, []string{"foo=bar", "baz=qux"})
		require.NoError(t, err)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, result)
	})

	t.Run("AnySlice", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, []any{
			map[string]any{"foo": "bar"},
		})
		require.NoError(t, err)
		assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
	})

	t.Run("MapWithoutSchema", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, map[string]any{
			"foo": "bar",
			"baz": "qux",
		})
		require.NoError(t, err)
		// Keys are sorted
		assert.Len(t, result, 2)
	})

	t.Run("MapWithSchemaNoValues", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, map[string]any{
			"schema": "schema.json",
		})
		require.NoError(t, err)
		assert.Empty(t, result)
	})

	t.Run("MapWithSchemaAndValues", func(t *testing.T) {
		t.Parallel()
		result, err := parseParamValue(ctx, map[string]any{
			"schema": "schema.json",
			"values": map[string]any{"foo": "bar"},
		})
		require.NoError(t, err)
		assert.Equal(t, []paramPair{{Name: "foo", Value: "bar"}}, result)
	})

	t.Run("InvalidType", func(t *testing.T) {
		t.Parallel()
		_, err := parseParamValue(ctx, 123)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid parameter value")
	})
}

func TestOverrideParams(t *testing.T) {
	t.Parallel()

	t.Run("OverrideByName", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{
			{Name: "foo", Value: "original"},
			{Name: "bar", Value: "keep"},
		}
		override := []paramPair{
			{Name: "foo", Value: "overridden"},
		}
		overrideParams(&params, override)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "overridden"},
			{Name: "bar", Value: "keep"},
		}, params)
	})

	t.Run("AddNewNamedParam", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{
			{Name: "foo", Value: "bar"},
		}
		override := []paramPair{
			{Name: "baz", Value: "qux"},
		}
		overrideParams(&params, override)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, params)
	})

	t.Run("OverrideByPosition", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{
			{Name: "", Value: "first"},
			{Name: "", Value: "second"},
		}
		override := []paramPair{
			{Name: "", Value: "new-first"},
		}
		overrideParams(&params, override)
		assert.Equal(t, []paramPair{
			{Name: "", Value: "new-first"},
			{Name: "", Value: "second"},
		}, params)
	})

	t.Run("AddPositionalBeyondLength", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{
			{Name: "", Value: "first"},
		}
		override := []paramPair{
			{Name: "", Value: "new-first"},
			{Name: "", Value: "new-second"},
		}
		overrideParams(&params, override)
		assert.Equal(t, []paramPair{
			{Name: "", Value: "new-first"},
			{Name: "", Value: "new-second"},
		}, params)
	})

	t.Run("EmptyOverride", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{
			{Name: "foo", Value: "bar"},
		}
		overrideParams(&params, []paramPair{})
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
		}, params)
	})

	t.Run("EmptyParams", func(t *testing.T) {
		t.Parallel()
		params := []paramPair{}
		override := []paramPair{
			{Name: "foo", Value: "bar"},
		}
		overrideParams(&params, override)
		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
		}, params)
	})
}

func TestParseParams(t *testing.T) {
	t.Parallel()

	t.Run("SimpleParams", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		var params []paramPair
		var envs []string

		err := parseParams(ctx, "foo=bar baz=qux", &params, &envs)
		require.NoError(t, err)

		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, params)
		// NoEval flag prevents env vars from being added
		assert.Empty(t, envs)
	})

	t.Run("PositionalParamsGetNames", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		var params []paramPair
		var envs []string

		err := parseParams(ctx, "one two three", &params, &envs)
		require.NoError(t, err)

		// Positional params get numbered names
		assert.Equal(t, []paramPair{
			{Name: "1", Value: "one"},
			{Name: "2", Value: "two"},
			{Name: "3", Value: "three"},
		}, params)
	})

	t.Run("WithEvalAddsEnvs", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		var params []paramPair
		var envs []string

		err := parseParams(ctx, "foo=bar baz=qux", &params, &envs)
		require.NoError(t, err)

		assert.Equal(t, []paramPair{
			{Name: "foo", Value: "bar"},
			{Name: "baz", Value: "qux"},
		}, params)
		assert.Equal(t, []string{"foo=bar", "baz=qux"}, envs)
	})

	t.Run("VariableReference", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		var params []paramPair
		var envs []string

		// Second param references first param
		err := parseParams(ctx, []any{
			map[string]any{"BASE": "/opt"},
			map[string]any{"PATH_VAR": "${BASE}/bin"},
		}, &params, &envs)
		require.NoError(t, err)

		assert.Equal(t, "/opt", params[0].Value)
		assert.Equal(t, "/opt/bin", params[1].Value)
	})

	t.Run("NilInput", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		var params []paramPair
		var envs []string

		err := parseParams(ctx, nil, &params, &envs)
		require.NoError(t, err)
		assert.Empty(t, params)
		assert.Empty(t, envs)
	})

	t.Run("InvalidInput", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		var params []paramPair
		var envs []string

		err := parseParams(ctx, 123, &params, &envs)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid parameter value")
	})
}

func TestEvalParamValue(t *testing.T) {
	t.Run("SimpleValue", func(t *testing.T) {
		ctx := BuildContext{
			ctx: context.Background(),
		}

		result, err := evalParamValue(ctx, "hello", nil)
		require.NoError(t, err)
		assert.Equal(t, "hello", result)
	})

	t.Run("WithAccumulatedVars", func(t *testing.T) {
		ctx := BuildContext{
			ctx: context.Background(),
		}

		vars := map[string]string{"BASE": "/opt"}
		result, err := evalParamValue(ctx, "${BASE}/bin", vars)
		require.NoError(t, err)
		assert.Equal(t, "/opt/bin", result)
	})

	t.Run("WithBuildEnv", func(t *testing.T) {
		ctx := BuildContext{
			ctx:      context.Background(),
			buildEnv: map[string]string{"ENV_VAR": "value"},
		}

		result, err := evalParamValue(ctx, "${ENV_VAR}", nil)
		require.NoError(t, err)
		assert.Equal(t, "value", result)
	})

	t.Run("AccumulatedVarsPrecedence", func(t *testing.T) {
		ctx := BuildContext{
			ctx:      context.Background(),
			buildEnv: map[string]string{"VAR": "from-env"},
		}

		vars := map[string]string{"VAR": "from-accumulated"}
		result, err := evalParamValue(ctx, "${VAR}", vars)
		require.NoError(t, err)
		// Accumulated vars should take precedence (added first to options)
		assert.Equal(t, "from-accumulated", result)
	})
}

@@ -26,78 +26,3 @@ func buildScheduler(values []string) ([]core.Schedule, error) {

	return ret, nil
}

// parseScheduleMap parses the schedule map and populates the starts, stops,
// and restarts slices. Each key in the map must be either "start", "stop", or
// "restart". The value can be Case 1 or Case 2.
//
// Case 1: The value is a string
// Case 2: The value is an array of strings
//
// Example:
// ```yaml
// schedule:
//
//	start: "0 1 * * *"
//	stop: "0 18 * * *"
//	restart:
//	  - "0 1 * * *"
//	  - "0 18 * * *"
//
// ```
func parseScheduleMap(
	scheduleMap map[string]any, starts, stops, restarts *[]string,
) error {
	for key, v := range scheduleMap {
		var values []string

		switch v := v.(type) {
		case string:
			// Case 1. schedule is a string.
			values = append(values, v)

		case []any:
			// Case 2. schedule is an array of strings.
			// Append all the schedules to the values slice.
			for _, s := range v {
				s, ok := s.(string)
				if !ok {
					return core.NewValidationError("schedule", s, ErrScheduleMustBeStringOrArray)
				}
				values = append(values, s)
			}

		}

		var targets *[]string

		switch scheduleKey(key) {
		case scheduleKeyStart:
			targets = starts

		case scheduleKeyStop:
			targets = stops

		case scheduleKeyRestart:
			targets = restarts

		}

		for _, v := range values {
			if _, err := cronParser.Parse(v); err != nil {
				return core.NewValidationError("schedule", v, fmt.Errorf("%w: %s", ErrInvalidSchedule, err))
			}
			*targets = append(*targets, v)
		}
	}

	return nil
}

type scheduleKey string

const (
	scheduleKeyStart   scheduleKey = "start"
	scheduleKeyStop    scheduleKey = "stop"
	scheduleKeyRestart scheduleKey = "restart"
)

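Each schedule value is validated with cronParser before being appended. As a rough standalone illustration of that validation step, the sketch below uses the stock robfig/cron parser; the package's own cronParser may be configured differently (for example with seconds or descriptors), so this is an assumption, not the exact behavior.

// Sketch: validating cron expressions from a schedule map, assuming the
// standard robfig/cron parser.
package main

import (
	"fmt"

	"github.com/robfig/cron/v3"
)

func main() {
	for _, expr := range []string{"0 1 * * *", "0 18 * * *", "not-a-cron"} {
		if _, err := cron.ParseStandard(expr); err != nil {
			fmt.Printf("%q invalid: %v\n", expr, err)
			continue
		}
		fmt.Printf("%q ok\n", expr)
	}
}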
174
internal/core/spec/schema.go
Normal file
174
internal/core/spec/schema.go
Normal file
@ -0,0 +1,174 @@
|
||||
package spec
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/dagu-org/dagu/internal/common/fileutil"
|
||||
"github.com/google/jsonschema-go/jsonschema"
|
||||
)
|
||||
|
||||
// resolveSchemaFromParams extracts a schema reference from params and resolves it.
|
||||
// Returns (nil, nil) if no schema is declared.
|
||||
func resolveSchemaFromParams(params any, workingDir, dagLocation string) (*jsonschema.Resolved, error) {
|
||||
schemaRef := extractSchemaReference(params)
|
||||
if schemaRef == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return getSchemaFromRef(workingDir, dagLocation, schemaRef)
|
||||
}
|
||||
|
||||
// Schema Ref can be a local file (relative or absolute paths), or a remote URL
|
||||
func getSchemaFromRef(workingDir string, dagLocation string, schemaRef string) (*jsonschema.Resolved, error) {
|
||||
var schemaData []byte
|
||||
var err error
|
||||
|
||||
// Check if it's a URL or file path
|
||||
if strings.HasPrefix(schemaRef, "http://") || strings.HasPrefix(schemaRef, "https://") {
|
||||
schemaData, err = loadSchemaFromURL(schemaRef)
|
||||
} else {
|
||||
schemaData, err = loadSchemaFromFile(workingDir, dagLocation, schemaRef)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load schema from %s: %w", schemaRef, err)
|
||||
}
|
||||
|
||||
var schema jsonschema.Schema
|
||||
if err := json.Unmarshal(schemaData, &schema); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse schema JSON: %w", err)
|
||||
}
|
||||
|
||||
resolveOptions := &jsonschema.ResolveOptions{
|
||||
ValidateDefaults: true,
|
||||
}
|
||||
|
||||
resolvedSchema, err := schema.Resolve(resolveOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve schema: %w", err)
|
||||
}
|
||||
|
||||
return resolvedSchema, nil
|
||||
}
|
||||
|
||||
// loadSchemaFromURL loads a JSON schema from a URL.
|
||||
func loadSchemaFromURL(schemaURL string) (data []byte, err error) {
|
||||
// Validate URL to prevent potential security issues (and satisfy linter :P)
|
||||
parsedURL, err := url.Parse(schemaURL)
|
||||
if err != nil {
|
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return nil, fmt.Errorf("unsupported URL scheme: %s", parsedURL.Scheme)
	}

	req, err := http.NewRequest("GET", schemaURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
	}

	data, err = io.ReadAll(resp.Body)
	return data, err
}

// loadSchemaFromFile loads a JSON schema from a file path.
func loadSchemaFromFile(workingDir string, dagLocation string, filePath string) ([]byte, error) {
	// Try to resolve the schema file path in the following order:
	// 1) Current working directory (default ResolvePath behavior)
	// 2) DAG's workingDir value
	// 3) Directory of the DAG file (where it was loaded from)

	var tried []string

	// Attempts a candidate by joining base and filePath (if base provided),
	// resolving env/tilde + absolute path, checking existence, and reading.
	tryCandidate := func(label, base string) ([]byte, string, error) {
		var candidate string
		if strings.TrimSpace(base) == "" {
			candidate = filePath
		} else {
			candidate = filepath.Join(base, filePath)
		}
		resolved, err := fileutil.ResolvePath(candidate)
		if err != nil {
			tried = append(tried, fmt.Sprintf("%s: resolve error: %v", label, err))
			return nil, "", err
		}
		if !fileutil.FileExists(resolved) {
			tried = append(tried, fmt.Sprintf("%s: %s", label, resolved))
			return nil, resolved, os.ErrNotExist
		}
		data, err := os.ReadFile(resolved) // #nosec G304 - validated path
		if err != nil {
			tried = append(tried, fmt.Sprintf("%s: %s (read error: %v)", label, resolved, err))
			return nil, resolved, err
		}
		return data, resolved, nil
	}

	// 1) As provided (CWD/env/tilde expansion handled by ResolvePath)
	if data, _, err := tryCandidate("cwd", ""); err == nil {
		return data, nil
	}

	// 2) From DAG's workingDir value if present
	if wd := strings.TrimSpace(workingDir); wd != "" {
		if data, _, err := tryCandidate(fmt.Sprintf("workingDir(%s)", wd), wd); err == nil {
			return data, nil
		}
	}

	// 3) From the directory of the DAG file used to build
	if dagLocation != "" {
		base := filepath.Dir(dagLocation)
		if data, _, err := tryCandidate(fmt.Sprintf("dagDir(%s)", base), base); err == nil {
			return data, nil
		}
	}

	if len(tried) == 0 {
		return nil, fmt.Errorf("failed to resolve schema file path: %s (no candidates)", filePath)
	}
	return nil, fmt.Errorf("schema file not found for %q; tried %s", filePath, strings.Join(tried, ", "))
}

// extractSchemaReference extracts the schema reference from a params map.
// Returns the schema reference as a string if present and valid, empty string otherwise.
func extractSchemaReference(params any) string {
	paramsMap, ok := params.(map[string]any)
	if !ok {
		return ""
	}

	schemaRef, hasSchema := paramsMap["schema"]
	if !hasSchema {
		return ""
	}

	schemaRefStr, ok := schemaRef.(string)
	if !ok {
		return ""
	}

	return schemaRefStr
}
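Taken together, the helpers above give file-based schema references a three-step lookup. For orientation only (not part of the commit), a minimal sketch of calling the helper from inside this package; the paths are hypothetical:

func exampleSchemaLookup() ([]byte, error) {
	// Sketch only — hypothetical values illustrating the lookup order:
	// 1) "schemas/params.json" relative to the current working directory,
	// 2) relative to the DAG's workingDir, 3) relative to the DAG file's directory.
	workingDir := "/srv/dagu/workflows"      // hypothetical
	dagLocation := "/srv/dagu/dags/etl.yaml" // hypothetical
	return loadSchemaFromFile(workingDir, dagLocation, "schemas/params.json")
}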
833 internal/core/spec/schema_test.go Normal file
@ -0,0 +1,833 @@
package spec

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestExtractSchemaReference(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		params   any
		expected string
	}{
		{
			name:     "Nil",
			params:   nil,
			expected: "",
		},
		{
			name:     "NotAMap",
			params:   "string value",
			expected: "",
		},
		{
			name:     "EmptyMap",
			params:   map[string]any{},
			expected: "",
		},
		{
			name:     "NoSchemaKey",
			params:   map[string]any{"values": map[string]any{"foo": "bar"}},
			expected: "",
		},
		{
			name:     "SchemaKeyNotString",
			params:   map[string]any{"schema": 123},
			expected: "",
		},
		{
			name:     "SchemaKeyIsMap",
			params:   map[string]any{"schema": map[string]any{"type": "object"}},
			expected: "",
		},
		{
			name:     "ValidSchemaReference",
			params:   map[string]any{"schema": "schema.json"},
			expected: "schema.json",
		},
		{
			name:     "ValidSchemaReferenceWithValues",
			params:   map[string]any{"schema": "./schemas/params.json", "values": map[string]any{"foo": "bar"}},
			expected: "./schemas/params.json",
		},
		{
			name:     "HTTPSchemaReference",
			params:   map[string]any{"schema": "https://example.com/schema.json"},
			expected: "https://example.com/schema.json",
		},
		{
			name:     "EmptySchemaString",
			params:   map[string]any{"schema": ""},
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := extractSchemaReference(tt.params)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestLoadSchemaFromURL(t *testing.T) {
	t.Parallel()

	t.Run("SuccessfulLoad", func(t *testing.T) {
		t.Parallel()

		schemaContent := `{"type": "object", "properties": {"foo": {"type": "string"}}}`
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(schemaContent))
		}))
		defer server.Close()

		data, err := loadSchemaFromURL(server.URL + "/schema.json")
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})

	t.Run("HTTPError", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}))
		defer server.Close()

		_, err := loadSchemaFromURL(server.URL + "/missing.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "404")
	})

	t.Run("ServerError", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer server.Close()

		_, err := loadSchemaFromURL(server.URL + "/schema.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "500")
	})

	t.Run("InvalidURL", func(t *testing.T) {
		t.Parallel()

		_, err := loadSchemaFromURL("://invalid-url")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid")
	})

	t.Run("UnsupportedScheme", func(t *testing.T) {
		t.Parallel()

		_, err := loadSchemaFromURL("ftp://example.com/schema.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "unsupported URL scheme")
	})

	t.Run("ConnectionRefused", func(t *testing.T) {
		t.Parallel()

		// Use a port that's unlikely to be in use
		_, err := loadSchemaFromURL("http://127.0.0.1:59999/schema.json")
		require.Error(t, err)
	})
}

func TestLoadSchemaFromFile(t *testing.T) {
	t.Parallel()

	schemaContent := `{"type": "object"}`

	t.Run("AbsolutePath", func(t *testing.T) {
		t.Parallel()

		tmpDir := t.TempDir()
		schemaPath := filepath.Join(tmpDir, "schema.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))

		data, err := loadSchemaFromFile("", "", schemaPath)
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})

	t.Run("FromWorkingDir", func(t *testing.T) {
		t.Parallel()

		workingDir := t.TempDir()
		schemaPath := filepath.Join(workingDir, "schema.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))

		data, err := loadSchemaFromFile(workingDir, "", "schema.json")
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})

	t.Run("FromDAGDirectory", func(t *testing.T) {
		t.Parallel()

		dagDir := t.TempDir()
		schemaPath := filepath.Join(dagDir, "schema.json")
		dagPath := filepath.Join(dagDir, "dag.yaml")
		require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))

		data, err := loadSchemaFromFile("", dagPath, "schema.json")
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})

	t.Run("WorkingDirTakesPrecedenceOverDAGDir", func(t *testing.T) {
		t.Parallel()

		workingDir := t.TempDir()
		dagDir := t.TempDir()

		wdSchema := `{"type": "object", "title": "working-dir"}`
		dagSchema := `{"type": "object", "title": "dag-dir"}`

		require.NoError(t, os.WriteFile(filepath.Join(workingDir, "schema.json"), []byte(wdSchema), 0600))
		require.NoError(t, os.WriteFile(filepath.Join(dagDir, "schema.json"), []byte(dagSchema), 0600))

		dagPath := filepath.Join(dagDir, "dag.yaml")
		data, err := loadSchemaFromFile(workingDir, dagPath, "schema.json")
		require.NoError(t, err)
		assert.Equal(t, wdSchema, string(data))
	})

	t.Run("FileNotFound", func(t *testing.T) {
		t.Parallel()

		_, err := loadSchemaFromFile("", "", "nonexistent-schema.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})

	t.Run("FileNotFoundWithWorkingDir", func(t *testing.T) {
		t.Parallel()

		workingDir := t.TempDir()
		_, err := loadSchemaFromFile(workingDir, "", "nonexistent.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})

	t.Run("FileNotFoundWithDAGDir", func(t *testing.T) {
		t.Parallel()

		dagDir := t.TempDir()
		dagPath := filepath.Join(dagDir, "dag.yaml")
		_, err := loadSchemaFromFile("", dagPath, "nonexistent.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "not found")
	})

	t.Run("SubdirectoryPath", func(t *testing.T) {
		t.Parallel()

		workingDir := t.TempDir()
		schemasDir := filepath.Join(workingDir, "schemas")
		require.NoError(t, os.MkdirAll(schemasDir, 0755))
		schemaPath := filepath.Join(schemasDir, "params.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))

		data, err := loadSchemaFromFile(workingDir, "", "schemas/params.json")
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})

	t.Run("EmptyWorkingDirAndDAGLocation", func(t *testing.T) {
		t.Parallel()

		_, err := loadSchemaFromFile("", "", "schema.json")
		require.Error(t, err)
	})

	t.Run("WhitespaceOnlyWorkingDir", func(t *testing.T) {
		t.Parallel()

		dagDir := t.TempDir()
		schemaPath := filepath.Join(dagDir, "schema.json")
		dagPath := filepath.Join(dagDir, "dag.yaml")
		require.NoError(t, os.WriteFile(schemaPath, []byte(schemaContent), 0600))

		// Whitespace-only workingDir should be skipped, fall back to dagDir
		data, err := loadSchemaFromFile(" ", dagPath, "schema.json")
		require.NoError(t, err)
		assert.Equal(t, schemaContent, string(data))
	})
}

func TestGetSchemaFromRef(t *testing.T) {
	t.Parallel()

	validSchemaContent := `{
		"$schema": "https://json-schema.org/draft/2020-12/schema",
		"type": "object",
		"properties": {
			"name": {"type": "string"}
		}
	}`

	t.Run("LocalFileSchema", func(t *testing.T) {
		t.Parallel()

		tmpDir := t.TempDir()
		schemaPath := filepath.Join(tmpDir, "schema.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))

		resolved, err := getSchemaFromRef("", "", schemaPath)
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})

	t.Run("RemoteURLSchema", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(validSchemaContent))
		}))
		defer server.Close()

		resolved, err := getSchemaFromRef("", "", server.URL+"/schema.json")
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})

	t.Run("HTTPSSchemaReference", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(validSchemaContent))
		}))
		defer server.Close()

		// This will fail due to self-signed cert, but tests the https:// detection
		_, err := getSchemaFromRef("", "", server.URL+"/schema.json")
		// We expect an error due to certificate verification
		require.Error(t, err)
	})

	t.Run("InvalidJSONSchema", func(t *testing.T) {
		t.Parallel()

		tmpDir := t.TempDir()
		schemaPath := filepath.Join(tmpDir, "invalid.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte("not valid json"), 0600))

		_, err := getSchemaFromRef("", "", schemaPath)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "parse schema JSON")
	})

	t.Run("SchemaFileNotFound", func(t *testing.T) {
		t.Parallel()

		_, err := getSchemaFromRef("", "", "nonexistent.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to load schema")
	})

	t.Run("URLNotFound", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}))
		defer server.Close()

		_, err := getSchemaFromRef("", "", server.URL+"/missing.json")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "failed to load schema")
	})
}

func TestResolveSchemaFromParams(t *testing.T) {
	t.Parallel()

	validSchemaContent := `{
		"$schema": "https://json-schema.org/draft/2020-12/schema",
		"type": "object",
		"properties": {
			"batch_size": {"type": "integer", "default": 10}
		}
	}`

	t.Run("NoSchemaReference", func(t *testing.T) {
		t.Parallel()

		resolved, err := resolveSchemaFromParams(nil, "", "")
		require.NoError(t, err)
		assert.Nil(t, resolved)
	})

	t.Run("ParamsNotMap", func(t *testing.T) {
		t.Parallel()

		resolved, err := resolveSchemaFromParams("string params", "", "")
		require.NoError(t, err)
		assert.Nil(t, resolved)
	})

	t.Run("ParamsWithoutSchema", func(t *testing.T) {
		t.Parallel()

		params := map[string]any{
			"values": map[string]any{"foo": "bar"},
		}
		resolved, err := resolveSchemaFromParams(params, "", "")
		require.NoError(t, err)
		assert.Nil(t, resolved)
	})

	t.Run("ParamsWithValidSchema", func(t *testing.T) {
		t.Parallel()

		tmpDir := t.TempDir()
		schemaPath := filepath.Join(tmpDir, "schema.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))

		params := map[string]any{
			"schema": schemaPath,
			"values": map[string]any{"batch_size": 20},
		}
		resolved, err := resolveSchemaFromParams(params, "", "")
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})

	t.Run("ParamsWithInvalidSchemaPath", func(t *testing.T) {
		t.Parallel()

		params := map[string]any{
			"schema": "nonexistent.json",
		}
		_, err := resolveSchemaFromParams(params, "", "")
		require.Error(t, err)
	})

	t.Run("ParamsWithRemoteSchema", func(t *testing.T) {
		t.Parallel()

		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(validSchemaContent))
		}))
		defer server.Close()

		params := map[string]any{
			"schema": server.URL + "/schema.json",
		}
		resolved, err := resolveSchemaFromParams(params, "", "")
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})

	t.Run("UsesWorkingDirForResolution", func(t *testing.T) {
		t.Parallel()

		workingDir := t.TempDir()
		schemaPath := filepath.Join(workingDir, "schema.json")
		require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))

		params := map[string]any{
			"schema": "schema.json",
		}
		resolved, err := resolveSchemaFromParams(params, workingDir, "")
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})

	t.Run("UsesDAGLocationForResolution", func(t *testing.T) {
		t.Parallel()

		dagDir := t.TempDir()
		schemaPath := filepath.Join(dagDir, "schema.json")
		dagPath := filepath.Join(dagDir, "dag.yaml")
		require.NoError(t, os.WriteFile(schemaPath, []byte(validSchemaContent), 0600))

		params := map[string]any{
			"schema": "schema.json",
		}
		resolved, err := resolveSchemaFromParams(params, "", dagPath)
		require.NoError(t, err)
		assert.NotNil(t, resolved)
	})
}

func TestBuildParamsWithLocalSchemaReference(t *testing.T) {
	t.Parallel()

	schemaContent := `{
		"$schema": "https://json-schema.org/draft/2020-12/schema",
		"type": "object",
		"properties": {
			"batch_size": {
				"type": "integer",
				"default": 10,
				"minimum": 1
			},
			"environment": {
				"type": "string",
				"default": "dev",
				"enum": ["dev", "staging", "prod"]
			}
		}
	}`

	tmpFile, err := os.CreateTemp("", "test-schema-*.json")
	require.NoError(t, err)
	defer func() { _ = os.Remove(tmpFile.Name()) }()

	_, err = tmpFile.WriteString(schemaContent)
	require.NoError(t, err)
	require.NoError(t, tmpFile.Close())

	data := []byte(fmt.Sprintf(`
params:
  schema: "%s"
  values:
    batch_size: 25
    environment: "staging"
`, tmpFile.Name()))

	dag, err := LoadYAML(context.Background(), data)
	require.NoError(t, err)

	require.Len(t, dag.Params, 2)
	require.Contains(t, dag.Params, "batch_size=25")
	require.Contains(t, dag.Params, "environment=staging")
}

func TestBuildParamsWithRemoteSchemaReference(t *testing.T) {
	t.Parallel()

	mux := http.NewServeMux()
	mux.HandleFunc("/schemas/dag-params.json", func(w http.ResponseWriter, _ *http.Request) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {
					"type": "integer",
					"default": 10,
					"minimum": 1
				},
				"environment": {
					"type": "string",
					"default": "dev",
					"enum": ["dev", "staging", "prod"]
				}
			}
		}`
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(schemaContent))
	})

	server := httptest.NewServer(mux)
	defer server.Close()

	data := []byte(fmt.Sprintf(`
params:
  schema: "%s/schemas/dag-params.json"
  values:
    batch_size: 50
    environment: "prod"
`, server.URL))

	dag, err := LoadYAML(context.Background(), data)
	require.NoError(t, err)

	require.Len(t, dag.Params, 2)
	require.Contains(t, dag.Params, "batch_size=50")
	require.Contains(t, dag.Params, "environment=prod")
}

func TestBuildParamsSchemaResolution(t *testing.T) {
	t.Run("FromWorkingDir", func(t *testing.T) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {"type": "integer", "default": 42}
			}
		}`

		wd := t.TempDir()
		wdSchema := filepath.Join(wd, "schema.json")
		require.NoError(t, os.WriteFile(wdSchema, []byte(schemaContent), 0600))

		origWD, err := os.Getwd()
		require.NoError(t, err)
		t.Cleanup(func() {
			if err := os.Chdir(origWD); err != nil {
				t.Fatalf("failed to restore working directory: %v", err)
			}
		})

		data := []byte(fmt.Sprintf(`
workingDir: %s
params:
  schema: "schema.json"
  values:
    environment: "dev"
`, wd))

		dag, err := LoadYAML(context.Background(), data)
		require.NoError(t, err)

		require.Contains(t, dag.Params, "batch_size=42")
	})

	t.Run("FromDAGDir", func(t *testing.T) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {"type": "integer", "default": 7}
			}
		}`

		dir := t.TempDir()
		require.NoError(t, os.WriteFile(filepath.Join(dir, "schema.json"), []byte(schemaContent), 0600))

		dagYaml := []byte(`
params:
  schema: "schema.json"
  values:
    environment: "staging"
`)
		dagPath := filepath.Join(dir, "dag.yaml")
		require.NoError(t, os.WriteFile(dagPath, dagYaml, 0600))

		dag, err := Load(context.Background(), dagPath)
		require.NoError(t, err)

		require.Contains(t, dag.Params, "batch_size=7")
	})

	t.Run("PrefersCWDOverWorkingDir", func(t *testing.T) {
		cwdSchemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {"type": "integer", "default": 99}
			}
		}`
		wdSchemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {"type": "integer", "default": 11}
			}
		}`

		cwd := t.TempDir()
		wd := t.TempDir()
		require.NoError(t, os.WriteFile(filepath.Join(cwd, "schema.json"), []byte(cwdSchemaContent), 0600))
		require.NoError(t, os.WriteFile(filepath.Join(wd, "schema.json"), []byte(wdSchemaContent), 0600))

		orig, err := os.Getwd()
		require.NoError(t, err)
		require.NoError(t, os.Chdir(cwd))
		defer func() { _ = os.Chdir(orig) }()

		data := []byte(fmt.Sprintf(`
workingDir: %s
params:
  schema: "schema.json"
  values:
    environment: "dev"
`, wd))

		dag, err := LoadYAML(context.Background(), data)
		require.NoError(t, err)

		require.Contains(t, dag.Params, "batch_size=99")
	})
}

func TestBuildParamsSchemaValidation(t *testing.T) {
	t.Parallel()

	t.Run("SkipSchemaValidationFlag", func(t *testing.T) {
		t.Parallel()

		data := []byte(`
params:
  schema: "missing-schema.json"
  values:
    foo: "bar"
`)
		_, err := LoadYAML(context.Background(), data)
		require.Error(t, err)

		dag, err := LoadYAMLWithOpts(context.Background(), data, BuildOpts{
			Flags: BuildFlagSkipSchemaValidation,
		})
		require.NoError(t, err)

		require.Len(t, dag.Params, 1)
		require.Contains(t, dag.Params, "foo=bar")
	})

	t.Run("OverrideValidationFails", func(t *testing.T) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {
					"type": "integer",
					"default": 10,
					"minimum": 1,
					"maximum": 50
				},
				"environment": {
					"type": "string",
					"default": "dev",
					"enum": ["dev", "staging", "prod"]
				}
			}
		}`

		tmpFile, err := os.CreateTemp("", "test-schema-validation-*.json")
		require.NoError(t, err)
		defer func() { _ = os.Remove(tmpFile.Name()) }()

		_, err = tmpFile.WriteString(schemaContent)
		require.NoError(t, err)
		require.NoError(t, tmpFile.Close())

		data := []byte(fmt.Sprintf(`
params:
  schema: "%s"
`, tmpFile.Name()))

		cliParams := "batch_size=100 environment=prod"
		_, err = LoadYAML(context.Background(), data, WithParams(cliParams))
		require.Error(t, err)
		require.Contains(t, err.Error(), "parameter validation failed")
		require.Contains(t, err.Error(), "maximum: 100/1 is greater than 50")
	})

	t.Run("DefaultsApplied", func(t *testing.T) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {
					"type": "integer",
					"default": 25,
					"minimum": 1,
					"maximum": 100
				},
				"environment": {
					"type": "string",
					"default": "development",
					"enum": ["development", "staging", "production"]
				},
				"debug": {
					"type": "boolean",
					"default": true
				}
			}
		}`

		tmpFile, err := os.CreateTemp("", "test-schema-defaults-*.json")
		require.NoError(t, err)
		defer func() { _ = os.Remove(tmpFile.Name()) }()

		_, err = tmpFile.WriteString(schemaContent)
		require.NoError(t, err)
		require.NoError(t, tmpFile.Close())

		data := []byte(fmt.Sprintf(`
params:
  schema: "%s"
  values:
    batch_size: 75
`, tmpFile.Name()))

		dag, err := LoadYAML(context.Background(), data)
		require.NoError(t, err)

		require.Len(t, dag.Params, 3)
		require.Contains(t, dag.Params, "batch_size=75")
		require.Contains(t, dag.Params, "environment=development")
		require.Contains(t, dag.Params, "debug=true")
	})

	t.Run("DefaultsPreserveExistingValues", func(t *testing.T) {
		schemaContent := `{
			"$schema": "https://json-schema.org/draft/2020-12/schema",
			"type": "object",
			"properties": {
				"batch_size": {
					"type": "integer",
					"default": 25,
					"minimum": 1,
					"maximum": 100
				},
				"environment": {
					"type": "string",
					"default": "development",
					"enum": ["development", "staging", "production"]
				},
				"debug": {
					"type": "boolean",
					"default": true
				},
				"timeout": {
					"type": "integer",
					"default": 300
				}
			}
		}`

		tmpFile, err := os.CreateTemp("", "test-schema-preserve-*.json")
		require.NoError(t, err)
		defer func() { _ = os.Remove(tmpFile.Name()) }()

		_, err = tmpFile.WriteString(schemaContent)
		require.NoError(t, err)
		require.NoError(t, tmpFile.Close())

		data := []byte(fmt.Sprintf(`
params:
  schema: "%s"
  values:
    batch_size: 50
    environment: "production"
    debug: false
    timeout: 600
`, tmpFile.Name()))

		dag, err := LoadYAML(context.Background(), data)
		require.NoError(t, err)

		require.Len(t, dag.Params, 4)
		require.Contains(t, dag.Params, "batch_size=50")
		require.Contains(t, dag.Params, "environment=production")
		require.Contains(t, dag.Params, "debug=false")
		require.Contains(t, dag.Params, "timeout=600")
	})
}
1214 internal/core/spec/step.go Normal file (diff suppressed: file too large)
3094 internal/core/spec/step_test.go Normal file (diff suppressed: file too large)
207 internal/core/spec/types/continueon.go Normal file
@ -0,0 +1,207 @@
package types

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/goccy/go-yaml"
)

// ContinueOnValue represents a continue-on configuration that can be specified as:
//   - A string shorthand: "skipped" or "failed"
//   - A detailed map with configuration options
//
// YAML examples:
//
//	continueOn: skipped
//	continueOn: failed
//	continueOn:
//	  skipped: true
//	  failed: true
//	  exitCode: [0, 1]
//	  output: ["pattern1", "pattern2"]
//	  markSuccess: true
type ContinueOnValue struct {
	raw         any      // Original value for error reporting
	isSet       bool     // Whether the field was set in YAML
	skipped     bool     // Continue on skipped
	failed      bool     // Continue on failed
	exitCode    []int    // Specific exit codes to continue on
	output      []string // Output patterns to match
	markSuccess bool     // Mark step as success when condition is met
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (c *ContinueOnValue) UnmarshalYAML(data []byte) error {
	c.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("continueOn unmarshal error: %w", err)
	}
	c.raw = raw

	switch v := raw.(type) {
	case string:
		// String shorthand: "skipped" or "failed"
		switch strings.ToLower(strings.TrimSpace(v)) {
		case "skipped":
			c.skipped = true
		case "failed":
			c.failed = true
		default:
			return fmt.Errorf("continueOn: expected 'skipped' or 'failed', got %q", v)
		}
		return nil

	case map[string]any:
		// Detailed configuration
		return c.parseMap(v)

	case nil:
		c.isSet = false
		return nil

	default:
		return fmt.Errorf("continueOn must be string or map, got %T", v)
	}
}

func (c *ContinueOnValue) parseMap(m map[string]any) error {
	for key, v := range m {
		switch key {
		case "skipped":
			if b, ok := v.(bool); ok {
				c.skipped = b
			} else {
				return fmt.Errorf("continueOn.skipped: expected boolean, got %T", v)
			}
		case "failed", "failure":
			if b, ok := v.(bool); ok {
				c.failed = b
			} else {
				return fmt.Errorf("continueOn.%s: expected boolean, got %T", key, v)
			}
		case "exitCode":
			codes, err := parseIntArray(v)
			if err != nil {
				return fmt.Errorf("continueOn.exitCode: %w", err)
			}
			c.exitCode = codes
		case "output":
			outputs, err := parseStringArray(v)
			if err != nil {
				return fmt.Errorf("continueOn.output: %w", err)
			}
			c.output = outputs
		case "markSuccess":
			if b, ok := v.(bool); ok {
				c.markSuccess = b
			} else {
				return fmt.Errorf("continueOn.markSuccess: expected boolean, got %T", v)
			}
		default:
			return fmt.Errorf("continueOn: unknown key %q", key)
		}
	}
	return nil
}

func parseStringArray(v any) ([]string, error) {
	switch val := v.(type) {
	case nil:
		return nil, nil
	case string:
		if val == "" {
			return nil, nil
		}
		return []string{val}, nil
	case []any:
		var result []string
		for i, item := range val {
			s, ok := item.(string)
			if !ok {
				return nil, fmt.Errorf("[%d]: expected string, got %T", i, item)
			}
			result = append(result, s)
		}
		return result, nil
	default:
		return nil, fmt.Errorf("expected string or array of strings, got %T", v)
	}
}

func parseIntArray(v any) ([]int, error) {
	switch val := v.(type) {
	case []any:
		var result []int
		for i, item := range val {
			n, err := parseIntValue(item)
			if err != nil {
				return nil, fmt.Errorf("[%d]: %w", i, err)
			}
			result = append(result, n)
		}
		return result, nil
	case int:
		return []int{val}, nil
	case int64:
		return []int{int(val)}, nil
	case float64:
		return []int{int(val)}, nil
	case uint64:
		// Exit codes are small numbers, overflow won't happen in practice
		return []int{int(val)}, nil //nolint:gosec // Exit codes are small numbers
	case string:
		n, err := strconv.Atoi(val)
		if err != nil {
			return nil, fmt.Errorf("cannot parse %q as int: %w", val, err)
		}
		return []int{n}, nil
	default:
		return nil, fmt.Errorf("expected int or array of ints, got %T", v)
	}
}

func parseIntValue(item any) (int, error) {
	switch v := item.(type) {
	case int:
		return v, nil
	case int64:
		return int(v), nil
	case float64:
		return int(v), nil
	case uint64:
		return int(v), nil //nolint:gosec // Exit codes are small numbers
	case string:
		n, err := strconv.Atoi(v)
		if err != nil {
			return 0, fmt.Errorf("cannot parse %q as int: %w", v, err)
		}
		return n, nil
	default:
		return 0, fmt.Errorf("expected int, got %T", item)
	}
}

// IsZero returns true if continueOn was not set in YAML.
func (c ContinueOnValue) IsZero() bool { return !c.isSet }

// Value returns the original raw value for error reporting.
func (c ContinueOnValue) Value() any { return c.raw }

// Skipped returns true if should continue on skipped.
func (c ContinueOnValue) Skipped() bool { return c.skipped }

// Failed returns true if should continue on failed.
func (c ContinueOnValue) Failed() bool { return c.failed }

// ExitCode returns exit codes to continue on.
func (c ContinueOnValue) ExitCode() []int { return c.exitCode }

// Output returns output patterns to match.
func (c ContinueOnValue) Output() []string { return c.output }

// MarkSuccess returns true if step should be marked as success when condition is met.
func (c ContinueOnValue) MarkSuccess() bool { return c.markSuccess }
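For orientation, a minimal sketch (not part of the commit) of decoding the detailed map form from another package in this module (the package is internal, so it only compiles inside the repo); it mirrors what the tests below exercise:

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	// Sketch only: unmarshal the map form and read it back via the accessors.
	var c types.ContinueOnValue
	src := []byte("failed: true\nexitCode: [0, 1]\nmarkSuccess: true")
	if err := yaml.Unmarshal(src, &c); err != nil {
		panic(err)
	}
	fmt.Println(c.Failed(), c.ExitCode(), c.MarkSuccess()) // true [0 1] true
}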
413 internal/core/spec/types/continueon_test.go Normal file
@ -0,0 +1,413 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestContinueOnValue_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		input           string
		wantErr         bool
		errContains     string
		wantSkipped     bool
		wantFailed      bool
		wantExitCode    []int
		wantOutput      []string
		wantMarkSuccess bool
		checkIsZero     bool
		checkNotZero    bool
	}{
		{
			name:         "StringSkipped",
			input:        "skipped",
			wantSkipped:  true,
			wantFailed:   false,
			checkNotZero: true,
		},
		{
			name:        "StringFailed",
			input:       "failed",
			wantSkipped: false,
			wantFailed:  true,
		},
		{
			name:        "StringCaseInsensitiveSKIPPED",
			input:       "SKIPPED",
			wantSkipped: true,
		},
		{
			name:       "StringCaseInsensitiveFailed",
			input:      "Failed",
			wantFailed: true,
		},
		{
			name:        "StringWithWhitespace",
			input:       `" skipped "`,
			wantSkipped: true,
		},
		{
			name:        "MapFormSkippedOnly",
			input:       "skipped: true",
			wantSkipped: true,
			wantFailed:  false,
		},
		{
			name:        "MapFormFailedOnly",
			input:       "failed: true",
			wantSkipped: false,
			wantFailed:  true,
		},
		{
			name: "MapFormBoth",
			input: `
skipped: true
failed: true
`,
			wantSkipped: true,
			wantFailed:  true,
		},
		{
			name:         "MapWithExitCodesArray",
			input:        "exitCode: [0, 1, 2]",
			wantExitCode: []int{0, 1, 2},
		},
		{
			name:         "MapWithSingleExitCode",
			input:        "exitCode: 1",
			wantExitCode: []int{1},
		},
		{
			name:       "MapWithOutputPattern",
			input:      `output: "success|warning"`,
			wantOutput: []string{"success|warning"},
		},
		{
			name: "MapWithAllFields",
			input: `
skipped: true
failed: true
exitCode: [0, 1]
output: "OK"
markSuccess: true
`,
			wantSkipped:     true,
			wantFailed:      true,
			wantExitCode:    []int{0, 1},
			wantOutput:      []string{"OK"},
			wantMarkSuccess: true,
		},
		{
			name:        "InvalidStringValue",
			input:       "invalid",
			wantErr:     true,
			errContains: "expected 'skipped' or 'failed'",
		},
		{
			name:        "InvalidMapKey",
			input:       "unknown: true",
			wantErr:     true,
			errContains: "unknown key",
		},
		{
			name:        "InvalidSkippedType",
			input:       `skipped: "yes"`,
			wantErr:     true,
			errContains: "expected bool",
		},
		{
			name:        "InvalidExitCodeType",
			input:       `exitCode: "not a number"`,
			wantErr:     true,
			errContains: "cannot parse",
		},
		{
			name:        "InvalidTypeArray",
			input:       "[1, 2, 3]",
			wantErr:     true,
			errContains: "must be string or map",
		},
		{
			name:       "OutputAsStringArray",
			input:      `output: ["success", "warning", "info"]`,
			wantOutput: []string{"success", "warning", "info"},
		},
		{
			name:         "ExitCodeAsInt64",
			input:        "exitCode: 255",
			wantExitCode: []int{255},
		},
		{
			name:         "ExitCodeAsString",
			input:        `exitCode: "42"`,
			wantExitCode: []int{42},
		},
		{
			name:         "ExitCodeArrayWithMixedTypes",
			input:        `exitCode: [0, "1", 2]`,
			wantExitCode: []int{0, 1, 2},
		},
		{
			name:        "OutputInvalidTypeInArray",
			input:       `output: [123, true]`,
			wantErr:     true,
			errContains: "expected string",
		},
		{
			name:        "ExitCodeInvalidString",
			input:       `exitCode: "not-a-number"`,
			wantErr:     true,
			errContains: "cannot parse",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var c types.ContinueOnValue
			err := yaml.Unmarshal([]byte(tt.input), &c)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.checkIsZero {
				assert.True(t, c.IsZero())
				return
			}
			if tt.checkNotZero {
				assert.False(t, c.IsZero())
			}
			if tt.wantSkipped {
				assert.True(t, c.Skipped())
			}
			if tt.wantFailed {
				assert.True(t, c.Failed())
			}
			if tt.wantExitCode != nil {
				assert.Equal(t, tt.wantExitCode, c.ExitCode())
			}
			if tt.wantOutput != nil {
				assert.Equal(t, tt.wantOutput, c.Output())
			}
			if tt.wantMarkSuccess {
				assert.True(t, c.MarkSuccess())
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var c types.ContinueOnValue
		assert.True(t, c.IsZero())
		assert.False(t, c.Skipped())
		assert.False(t, c.Failed())
	})
}

func TestContinueOnValue_InStruct(t *testing.T) {
	t.Parallel()

	type StepConfig struct {
		Name       string                `yaml:"name"`
		ContinueOn types.ContinueOnValue `yaml:"continueOn"`
	}

	tests := []struct {
		name         string
		input        string
		wantSkipped  bool
		wantFailed   bool
		wantExitCode []int
		wantIsZero   bool
	}{
		{
			name: "ContinueOnAsString",
			input: `
name: my-step
continueOn: skipped
`,
			wantSkipped: true,
		},
		{
			name: "ContinueOnAsMap",
			input: `
name: my-step
continueOn:
  failed: true
  exitCode: [0, 1]
`,
			wantFailed:   true,
			wantExitCode: []int{0, 1},
		},
		{
			name:       "ContinueOnNotSet",
			input:      "name: my-step",
			wantIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg StepConfig
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantIsZero {
				assert.True(t, cfg.ContinueOn.IsZero())
				return
			}
			if tt.wantSkipped {
				assert.True(t, cfg.ContinueOn.Skipped())
			}
			if tt.wantFailed {
				assert.True(t, cfg.ContinueOn.Failed())
			}
			if tt.wantExitCode != nil {
				assert.Equal(t, tt.wantExitCode, cfg.ContinueOn.ExitCode())
			}
		})
	}
}

func TestContinueOnValue_Value(t *testing.T) {
	t.Parallel()

	t.Run("ValueReturnsRawString", func(t *testing.T) {
		t.Parallel()
		var c types.ContinueOnValue
		err := yaml.Unmarshal([]byte("skipped"), &c)
		require.NoError(t, err)
		assert.Equal(t, "skipped", c.Value())
	})

	t.Run("ValueReturnsMap", func(t *testing.T) {
		t.Parallel()
		var c types.ContinueOnValue
		err := yaml.Unmarshal([]byte("failed: true"), &c)
		require.NoError(t, err)
		val, ok := c.Value().(map[string]any)
		require.True(t, ok)
		assert.Equal(t, true, val["failed"])
	})
}

func TestContinueOnValue_EdgeCases(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name          string
		input         string
		wantErr       bool
		errContains   string
		checkIsZero   bool
		wantFailed    bool
		wantExitCode  []int
		wantOutputNil bool
	}{
		{
			name:        "NullValue",
			input:       "null",
			checkIsZero: true,
		},
		{
			name:       "FailureKeyAlias",
			input:      "failure: true",
			wantFailed: true,
		},
		{
			name:         "ExitCodeAsFloat",
			input:        "exitCode: 1.0",
			wantExitCode: []int{1},
		},
		{
			name:         "ExitCodeArrayWithFloat",
			input:        "exitCode: [1.0, 2.0]",
			wantExitCode: []int{1, 2},
		},
		{
			name:        "InvalidExitCodeTypeInArray",
			input:       "exitCode: [true]",
			wantErr:     true,
			errContains: "expected int",
		},
		{
			name:        "InvalidExitCodeTypeNotIntOrArray",
			input:       "exitCode: {key: value}",
			wantErr:     true,
			errContains: "expected int or array",
		},
		{
			name:          "OutputAsNil",
			input:         "output: null",
			wantOutputNil: true,
		},
		{
			name:          "OutputAsEmptyString",
			input:         `output: ""`,
			wantOutputNil: true,
		},
		{
			name:        "OutputInvalidTypeNotStringOrArray",
			input:       "output: 123",
			wantErr:     true,
			errContains: "expected string or array",
		},
		{
			name:        "MarkSuccessInvalidType",
			input:       `markSuccess: "yes"`,
			wantErr:     true,
			errContains: "expected bool",
		},
		{
			name:        "FailedInvalidType",
			input:       `failed: "yes"`,
			wantErr:     true,
			errContains: "expected bool",
		},
		{
			name:        "ExitCodeInvalidStringInArray",
			input:       `exitCode: ["not-a-number"]`,
			wantErr:     true,
			errContains: "cannot parse",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var c types.ContinueOnValue
			err := yaml.Unmarshal([]byte(tt.input), &c)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.checkIsZero {
				assert.True(t, c.IsZero())
			}
			if tt.wantFailed {
				assert.True(t, c.Failed())
			}
			if tt.wantExitCode != nil {
				assert.Equal(t, tt.wantExitCode, c.ExitCode())
			}
			if tt.wantOutputNil {
				assert.Nil(t, c.Output())
			}
		})
	}
}
15 internal/core/spec/types/doc.go Normal file
@ -0,0 +1,15 @@
// Package types provides typed union types for YAML fields that accept multiple formats.
//
// These types provide type-safe unmarshaling with early validation while maintaining
// full backward compatibility with existing YAML files.
//
// Design principles:
//  1. Each type captures all valid YAML representations for a field
//  2. Validation happens at unmarshal time, not at build time
//  3. Accessor methods provide type-safe access to the parsed values
//  4. IsZero() indicates whether the field was set in YAML
//
// The types in this package are designed to be used as drop-in replacements
// for `any` typed fields in the spec.definition struct, enabling gradual
// migration while maintaining backward compatibility.
package types
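To make the "drop-in replacement" point concrete, a minimal sketch (not part of the commit; the stepDef structs here are hypothetical) of migrating an untyped field to one of these union types:

// Before: the field accepts anything and is validated later at build time.
type stepDefBefore struct {
	ContinueOn any `yaml:"continueOn"`
}

// After: the union type validates at unmarshal time, and IsZero() reports
// whether the field appeared in the YAML at all.
type stepDefAfter struct {
	ContinueOn ContinueOnValue `yaml:"continueOn"`
}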
114 internal/core/spec/types/env.go Normal file
@ -0,0 +1,114 @@
package types

import (
	"fmt"
	"strings"

	"github.com/goccy/go-yaml"
)

// EnvValue represents environment variable configuration that can be specified as:
//   - A map of key-value pairs
//   - An array of maps (for ordered definitions)
//   - An array of "KEY=value" strings
//   - A mix of maps and strings in an array
//
// YAML examples:
//
//	env:
//	  KEY1: value1
//	  KEY2: value2
//
//	env:
//	  - KEY1: value1
//	  - KEY2: value2
//
//	env:
//	  - KEY1=value1
//	  - KEY2=value2
type EnvValue struct {
	raw     any        // Original value for error reporting
	isSet   bool       // Whether the field was set in YAML
	entries []EnvEntry // Parsed entries in order
}

// EnvEntry represents a single environment variable entry.
type EnvEntry struct {
	Key   string
	Value string
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (e *EnvValue) UnmarshalYAML(data []byte) error {
	e.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("env unmarshal error: %w", err)
	}
	e.raw = raw

	switch v := raw.(type) {
	case map[string]any:
		// Map of key-value pairs
		return e.parseMap(v)

	case []any:
		// Array of maps or strings
		return e.parseArray(v)

	case nil:
		e.isSet = false
		return nil

	default:
		return fmt.Errorf("env must be map or array, got %T", v)
	}
}

func (e *EnvValue) parseMap(m map[string]any) error {
	for key, v := range m {
		value := stringifyValue(v)
		e.entries = append(e.entries, EnvEntry{Key: key, Value: value})
	}
	return nil
}

func (e *EnvValue) parseArray(arr []any) error {
	for i, item := range arr {
		switch v := item.(type) {
		case map[string]any:
			for key, val := range v {
				value := stringifyValue(val)
				e.entries = append(e.entries, EnvEntry{Key: key, Value: value})
			}
		case string:
			key, val, found := strings.Cut(v, "=")
			if !found {
				return fmt.Errorf("env[%d]: invalid format %q (expected KEY=value)", i, v)
			}
			e.entries = append(e.entries, EnvEntry{Key: key, Value: val})
		default:
			return fmt.Errorf("env[%d]: expected map or string, got %T", i, item)
		}
	}
	return nil
}

func stringifyValue(v any) string {
	switch val := v.(type) {
	case string:
		return val
	default:
		return fmt.Sprintf("%v", val)
	}
}

// IsZero returns true if env was not set in YAML.
func (e EnvValue) IsZero() bool { return !e.isSet }

// Value returns the original raw value for error reporting.
func (e EnvValue) Value() any { return e.raw }

// Entries returns the parsed environment entries in order.
func (e EnvValue) Entries() []EnvEntry { return e.entries }
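One practical consequence worth noting: the array form preserves definition order, which matters when later entries reference earlier ones. A minimal sketch (not part of the commit), runnable from another package inside this module since the package is internal:

package main

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

func main() {
	// Sketch only: the array form keeps BASE before BIN in Entries().
	var e types.EnvValue
	src := []byte("- BASE=/opt/app\n- BIN=${BASE}/bin")
	if err := yaml.Unmarshal(src, &e); err != nil {
		panic(err)
	}
	for _, entry := range e.Entries() {
		fmt.Printf("%s=%s\n", entry.Key, entry.Value) // BASE first, then BIN
	}
}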
320 internal/core/spec/types/env_test.go Normal file
@ -0,0 +1,320 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestEnvValue_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name           string
		input          string
		wantErr        bool
		errContains    string
		wantEntryCount int
		wantEntries    map[string]string
		checkIsZero    bool
		checkNotZero   bool
		checkEmpty     bool
	}{
		{
			name: "MapForm",
			input: `
KEY1: value1
KEY2: value2
`,
			wantEntryCount: 2,
			wantEntries:    map[string]string{"KEY1": "value1", "KEY2": "value2"},
		},
		{
			name: "ArrayOfMapsPreservesOrder",
			input: `
- KEY1: value1
- KEY2: value2
- KEY3: value3
`,
			wantEntryCount: 3,
			wantEntries:    map[string]string{"KEY1": "value1", "KEY2": "value2", "KEY3": "value3"},
		},
		{
			name: "ArrayOfStrings",
			input: `
- KEY1=value1
- KEY2=value2
`,
			wantEntryCount: 2,
			wantEntries:    map[string]string{"KEY1": "value1"},
		},
		{
			name: "MixedArrayMapsAndStrings",
			input: `
- KEY1: value1
- KEY2=value2
- KEY3: value3
`,
			wantEntryCount: 3,
		},
		{
			name: "NumericValuesStringified",
			input: `
PORT: 8080
ENABLED: true
RATIO: 0.5
`,
			wantEntryCount: 3,
			wantEntries:    map[string]string{"PORT": "8080", "ENABLED": "true", "RATIO": "0.5"},
		},
		{
			name: "ValueWithEqualsSign",
			input: `
- CONNECTION_STRING=host=localhost;port=5432
`,
			wantEntryCount: 1,
			wantEntries:    map[string]string{"CONNECTION_STRING": "host=localhost;port=5432"},
		},
		{
			name:        "InvalidStringFormatNoEquals",
			input:       `["invalid_no_equals"]`,
			wantErr:     true,
			errContains: "expected KEY=value",
		},
		{
			name:        "InvalidTypeScalarString",
			input:       `"just a string"`,
			wantErr:     true,
			errContains: "must be map or array",
		},
		{
			name:         "EmptyMap",
			input:        "{}",
			checkNotZero: true,
			checkEmpty:   true,
		},
		{
			name:         "EmptyArray",
			input:        "[]",
			checkNotZero: true,
			checkEmpty:   true,
		},
		{
			name: "EnvironmentVariableReference",
			input: `
- PATH: ${HOME}/bin
- DERIVED: ${OTHER_VAR}
`,
			wantEntryCount: 2,
			wantEntries:    map[string]string{"PATH": "${HOME}/bin", "DERIVED": "${OTHER_VAR}"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var e types.EnvValue
			err := yaml.Unmarshal([]byte(tt.input), &e)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.checkIsZero {
				assert.True(t, e.IsZero())
				return
			}
			if tt.checkNotZero {
				assert.False(t, e.IsZero())
			}
			if tt.checkEmpty {
				assert.Empty(t, e.Entries())
				return
			}
			if tt.wantEntryCount > 0 {
				assert.Len(t, e.Entries(), tt.wantEntryCount)
			}
			if tt.wantEntries != nil {
				entries := e.Entries()
				keys := make(map[string]string)
				for _, entry := range entries {
					keys[entry.Key] = entry.Value
				}
				for k, v := range tt.wantEntries {
					assert.Equal(t, v, keys[k], "key %s should have value %s", k, v)
				}
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var e types.EnvValue
		assert.True(t, e.IsZero())
		assert.Nil(t, e.Entries())
	})
}

func TestEnvValue_InStruct(t *testing.T) {
	t.Parallel()

	type StepConfig struct {
		Name string         `yaml:"name"`
		Env  types.EnvValue `yaml:"env"`
	}

	tests := []struct {
		name           string
		input          string
		wantName       string
		wantEntryCount int
		wantIsZero     bool
		checkNotZero   bool
	}{
		{
			name: "EnvSetAsMap",
			input: `
name: my-step
env:
  DEBUG: "true"
  LOG_LEVEL: info
`,
			wantName:       "my-step",
			wantEntryCount: 2,
			checkNotZero:   true,
		},
		{
			name: "EnvSetAsArray",
			input: `
name: my-step
env:
  - DEBUG: "true"
  - LOG_LEVEL: info
`,
			wantEntryCount: 2,
		},
		{
			name:       "EnvNotSet",
			input:      "name: my-step",
			wantIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg StepConfig
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantName != "" {
				assert.Equal(t, tt.wantName, cfg.Name)
			}
			if tt.wantEntryCount > 0 {
				assert.Len(t, cfg.Env.Entries(), tt.wantEntryCount)
			}
			if tt.wantIsZero {
				assert.True(t, cfg.Env.IsZero())
			}
			if tt.checkNotZero {
				assert.False(t, cfg.Env.IsZero())
			}
		})
	}

	t.Run("EnvSetAsArrayPreservesOrder", func(t *testing.T) {
		t.Parallel()
		var cfg StepConfig
		err := yaml.Unmarshal([]byte(`
name: my-step
env:
  - DEBUG: "true"
  - LOG_LEVEL: info
`), &cfg)
		require.NoError(t, err)
		entries := cfg.Env.Entries()
		require.Len(t, entries, 2)
		assert.Equal(t, "DEBUG", entries[0].Key)
		assert.Equal(t, "LOG_LEVEL", entries[1].Key)
	})
}

func TestEnvValue_AdditionalCoverage(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		input       string
		wantErr     bool
		errContains string
		checkIsZero bool
	}{
		{
			name:        "InvalidTypeInArrayNumber",
			input:       "[123]",
			wantErr:     true,
			errContains: "expected map or string",
		},
		{
			name:        "InvalidTypeInArrayBoolean",
			input:       "[true]",
			wantErr:     true,
			errContains: "expected map or string",
		},
		{
			name:        "InvalidTypeNumber",
			input:       "123",
			wantErr:     true,
			errContains: "must be map or array",
		},
		{
			name:        "NullValue",
			input:       "null",
			checkIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var e types.EnvValue
			err := yaml.Unmarshal([]byte(tt.input), &e)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.checkIsZero {
				assert.True(t, e.IsZero())
			}
		})
	}

	t.Run("ValueReturnsRawMap", func(t *testing.T) {
		t.Parallel()
		var e types.EnvValue
		err := yaml.Unmarshal([]byte("KEY: value"), &e)
		require.NoError(t, err)
		val, ok := e.Value().(map[string]any)
		require.True(t, ok)
		assert.Equal(t, "value", val["KEY"])
	})

	t.Run("ValueReturnsRawArray", func(t *testing.T) {
		t.Parallel()
		var e types.EnvValue
		err := yaml.Unmarshal([]byte("[KEY=value]"), &e)
		require.NoError(t, err)
		val, ok := e.Value().([]any)
		require.True(t, ok)
		assert.Len(t, val, 1)
	})
}
77 internal/core/spec/types/logoutput.go Normal file
@ -0,0 +1,77 @@
package types

import (
	"fmt"
	"strings"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/goccy/go-yaml"
)

// LogOutputValue represents a log output configuration that can be unmarshaled from YAML.
// It accepts a string value that must be one of: "separate" or "merged".
// This type uses core.LogOutputMode to avoid type duplication.
type LogOutputValue struct {
	mode core.LogOutputMode
	set  bool // whether the value was explicitly set in YAML
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (l *LogOutputValue) UnmarshalYAML(data []byte) error {
	l.set = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("logOutput unmarshal error: %w", err)
	}

	switch v := raw.(type) {
	case string:
		value := strings.TrimSpace(strings.ToLower(v))
		switch value {
		case "separate", "":
			l.mode = core.LogOutputSeparate
		case "merged":
			l.mode = core.LogOutputMerged
		default:
			return fmt.Errorf("invalid logOutput value: %q (must be 'separate' or 'merged')", v)
		}
		return nil

	case nil:
		l.set = false
		return nil

	default:
		return fmt.Errorf("logOutput must be a string, got %T", v)
	}
}

// IsZero returns true if the value was not set in YAML.
func (l LogOutputValue) IsZero() bool {
	return !l.set
}

// Mode returns the log output mode.
// If the value was not set, it returns core.LogOutputSeparate as the default.
func (l LogOutputValue) Mode() core.LogOutputMode {
	if !l.set {
		return core.LogOutputSeparate
	}
	return l.mode
}

// String returns the string representation of the log output mode.
func (l LogOutputValue) String() string {
	return string(l.Mode())
}

// IsMerged returns true if the log output mode is merged.
func (l LogOutputValue) IsMerged() bool {
	return l.mode == core.LogOutputMerged
}

// IsSeparate returns true if the log output mode is separate.
func (l LogOutputValue) IsSeparate() bool {
	return !l.set || l.mode == core.LogOutputSeparate
}
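For orientation, a minimal usage sketch (a hypothetical example test, not part of the diff; the wrapper struct and field name are illustrative, the behavior follows the file above):

package types_test

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

// Example_logOutputValue is a hypothetical example: "MERGED" is trimmed and
// lowercased by UnmarshalYAML before being mapped onto core.LogOutputMerged.
func Example_logOutputValue() {
	var spec struct {
		LogOutput types.LogOutputValue `yaml:"logOutput"`
	}
	if err := yaml.Unmarshal([]byte("logOutput: MERGED"), &spec); err != nil {
		panic(err)
	}
	fmt.Println(spec.LogOutput.String(), spec.LogOutput.IsMerged())
	// Output: merged true
}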
163  internal/core/spec/types/logoutput_test.go  Normal file
@@ -0,0 +1,163 @@
package types

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestLogOutputValue_UnmarshalYAML(t *testing.T) {
	tests := []struct {
		name        string
		input       string
		wantMode    core.LogOutputMode
		wantSet     bool
		wantErr     bool
		errContains string
	}{
		{
			name:     "separate mode",
			input:    "logOutput: separate",
			wantMode: core.LogOutputSeparate,
			wantSet:  true,
		},
		{
			name:     "merged mode",
			input:    "logOutput: merged",
			wantMode: core.LogOutputMerged,
			wantSet:  true,
		},
		{
			name:     "merged mode uppercase",
			input:    "logOutput: MERGED",
			wantMode: core.LogOutputMerged,
			wantSet:  true,
		},
		{
			name:     "separate mode mixed case",
			input:    "logOutput: Separate",
			wantMode: core.LogOutputSeparate,
			wantSet:  true,
		},
		{
			name:     "empty string defaults to separate",
			input:    "logOutput: ''",
			wantMode: core.LogOutputSeparate,
			wantSet:  true,
		},
		{
			name:        "invalid value",
			input:       "logOutput: invalid",
			wantErr:     true,
			errContains: "invalid logOutput value",
		},
		{
			name:        "invalid value - both",
			input:       "logOutput: both",
			wantErr:     true,
			errContains: "invalid logOutput value",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var result struct {
				LogOutput LogOutputValue `yaml:"logOutput"`
			}

			err := yaml.Unmarshal([]byte(tt.input), &result)

			if tt.wantErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.errContains)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.wantMode, result.LogOutput.Mode())
			assert.Equal(t, tt.wantSet, !result.LogOutput.IsZero())
		})
	}
}

func TestLogOutputValue_UnmarshalYAML_InvalidType(t *testing.T) {
	var result struct {
		LogOutput LogOutputValue `yaml:"logOutput"`
	}

	err := yaml.Unmarshal([]byte("logOutput:\n - item1\n - item2"), &result)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "must be a string")
}

func TestLogOutputValue_DefaultValue(t *testing.T) {
	var result struct {
		LogOutput LogOutputValue `yaml:"logOutput"`
	}

	// When not set in YAML
	err := yaml.Unmarshal([]byte("other: value"), &result)
	require.NoError(t, err)

	// Should be zero
	assert.True(t, result.LogOutput.IsZero())
	// Default mode should be separate
	assert.Equal(t, core.LogOutputSeparate, result.LogOutput.Mode())
	assert.True(t, result.LogOutput.IsSeparate())
	assert.False(t, result.LogOutput.IsMerged())
}

func TestLogOutputValue_Methods(t *testing.T) {
	t.Run("IsMerged", func(t *testing.T) {
		merged := LogOutputValue{mode: core.LogOutputMerged, set: true}
		assert.True(t, merged.IsMerged())
		assert.False(t, merged.IsSeparate())

		separate := LogOutputValue{mode: core.LogOutputSeparate, set: true}
		assert.False(t, separate.IsMerged())
		assert.True(t, separate.IsSeparate())
	})

	t.Run("String", func(t *testing.T) {
		merged := LogOutputValue{mode: core.LogOutputMerged, set: true}
		assert.Equal(t, "merged", merged.String())

		separate := LogOutputValue{mode: core.LogOutputSeparate, set: true}
		assert.Equal(t, "separate", separate.String())

		unset := LogOutputValue{}
		assert.Equal(t, "separate", unset.String()) // default
	})
}

func TestLogOutputMode_Constants(t *testing.T) {
	// Ensure constants have expected values
	assert.Equal(t, core.LogOutputMode("separate"), core.LogOutputSeparate)
	assert.Equal(t, core.LogOutputMode("merged"), core.LogOutputMerged)
}

func TestLogOutputValue_UnmarshalYAML_NilValue(t *testing.T) {
	var result struct {
		LogOutput LogOutputValue `yaml:"logOutput"`
	}

	// Explicit null value in YAML
	err := yaml.Unmarshal([]byte("logOutput: null"), &result)
	require.NoError(t, err)
	assert.True(t, result.LogOutput.IsZero())
	assert.Equal(t, core.LogOutputSeparate, result.LogOutput.Mode())
}

func TestLogOutputValue_UnmarshalYAML_MapValue(t *testing.T) {
	var result struct {
		LogOutput LogOutputValue `yaml:"logOutput"`
	}

	// Map value should fail with "must be a string" error
	err := yaml.Unmarshal([]byte("logOutput:\n key: value"), &result)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "must be a string")
}
68  internal/core/spec/types/port.go  Normal file
@@ -0,0 +1,68 @@
package types

import (
	"fmt"

	"github.com/goccy/go-yaml"
)

// PortValue represents a port number that can be specified as either
// a string or an integer.
//
// YAML examples:
//
//	port: 22
//	port: "22"
type PortValue struct {
	raw   any    // Original value for error reporting
	isSet bool   // Whether the field was set in YAML
	value string // Normalized string value
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (p *PortValue) UnmarshalYAML(data []byte) error {
	p.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("port unmarshal error: %w", err)
	}
	p.raw = raw

	switch v := raw.(type) {
	case string:
		p.value = v
		return nil

	case int:
		p.value = fmt.Sprintf("%d", v)
		return nil

	case float64:
		if v != float64(int(v)) {
			return fmt.Errorf("port must be an integer, got %v", v)
		}
		p.value = fmt.Sprintf("%d", int(v))
		return nil

	case uint64:
		p.value = fmt.Sprintf("%d", v)
		return nil

	case nil:
		p.isSet = false
		return nil

	default:
		return fmt.Errorf("port must be string or number, got %T", v)
	}
}

// IsZero returns true if the port was not set in YAML.
func (p PortValue) IsZero() bool { return !p.isSet }

// Value returns the original raw value for error reporting.
func (p PortValue) Value() any { return p.raw }

// String returns the port as a string.
func (p PortValue) String() string { return p.value }
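A small usage sketch (hypothetical example test, not part of the diff): both YAML scalar forms normalize to the same string, which is what SSH/SMTP configs ultimately consume.

package types_test

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

// Example_portValue is a hypothetical example showing that quoted and
// unquoted ports decode to the same normalized string.
func Example_portValue() {
	for _, doc := range []string{`port: 22`, `port: "22"`} {
		var cfg struct {
			Port types.PortValue `yaml:"port"`
		}
		if err := yaml.Unmarshal([]byte(doc), &cfg); err != nil {
			panic(err)
		}
		fmt.Println(cfg.Port.String())
	}
	// Output:
	// 22
	// 22
}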
231  internal/core/spec/types/port_test.go  Normal file
@@ -0,0 +1,231 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestPortValue_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		input        string
		wantErr      bool
		errContains  string
		wantString   string
		checkNotZero bool
	}{
		{
			name:         "Integer",
			input:        "22",
			wantString:   "22",
			checkNotZero: true,
		},
		{
			name:       "String",
			input:      `"8080"`,
			wantString: "8080",
		},
		{
			name:       "LargePortNumber",
			input:      "65535",
			wantString: "65535",
		},
		{
			name:        "InvalidTypeArray",
			input:       "[22, 80]",
			wantErr:     true,
			errContains: "must be string or number",
		},
		{
			name:        "InvalidTypeMap",
			input:       "{port: 22}",
			wantErr:     true,
			errContains: "must be string or number",
		},
		{
			name:        "FloatWithDecimal",
			input:       "22.5",
			wantErr:     true,
			errContains: "port must be an integer",
		},
		{
			name:       "FloatWholeNumber",
			input:      "22.0",
			wantString: "22",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var p types.PortValue
			err := yaml.Unmarshal([]byte(tt.input), &p)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			assert.Equal(t, tt.wantString, p.String())
			if tt.checkNotZero {
				assert.False(t, p.IsZero())
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var p types.PortValue
		assert.True(t, p.IsZero())
		assert.Equal(t, "", p.String())
	})
}

func TestPortValue_InStruct(t *testing.T) {
	t.Parallel()

	type SSHConfig struct {
		Host string          `yaml:"host"`
		Port types.PortValue `yaml:"port"`
		User string          `yaml:"user"`
	}

	sshTests := []struct {
		name       string
		input      string
		wantHost   string
		wantPort   string
		wantUser   string
		wantIsZero bool
	}{
		{
			name: "PortAsInteger",
			input: `
host: example.com
port: 22
user: admin
`,
			wantHost: "example.com",
			wantPort: "22",
			wantUser: "admin",
		},
		{
			name: "PortAsString",
			input: `
host: example.com
port: "2222"
user: admin
`,
			wantPort: "2222",
		},
		{
			name: "PortNotSet",
			input: `
host: example.com
user: admin
`,
			wantIsZero: true,
		},
	}

	for _, tt := range sshTests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg SSHConfig
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantHost != "" {
				assert.Equal(t, tt.wantHost, cfg.Host)
			}
			if tt.wantPort != "" {
				assert.Equal(t, tt.wantPort, cfg.Port.String())
			}
			if tt.wantUser != "" {
				assert.Equal(t, tt.wantUser, cfg.User)
			}
			if tt.wantIsZero {
				assert.True(t, cfg.Port.IsZero())
			}
		})
	}

	t.Run("SMTPPort", func(t *testing.T) {
		t.Parallel()
		type SMTPConfig struct {
			Host     string          `yaml:"host"`
			Port     types.PortValue `yaml:"port"`
			Username string          `yaml:"username"`
		}
		var cfg SMTPConfig
		err := yaml.Unmarshal([]byte(`
host: smtp.example.com
port: 587
username: user
`), &cfg)
		require.NoError(t, err)
		assert.Equal(t, "587", cfg.Port.String())
	})
}

func TestPortValue_AdditionalCoverage(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		input           string
		wantValueNotNil bool
		wantValue       any
		wantIsZero      bool
		wantString      string
	}{
		{
			name:            "ValueReturnsRawInt",
			input:           "22",
			wantValueNotNil: true,
		},
		{
			name:      "ValueReturnsRawString",
			input:     `"22"`,
			wantValue: "22",
		},
		{
			name:       "NullValueSetsIsZeroFalse",
			input:      "null",
			wantIsZero: true,
		},
		{
			name:       "LargeInteger",
			input:      "99999",
			wantString: "99999",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var p types.PortValue
			err := yaml.Unmarshal([]byte(tt.input), &p)
			require.NoError(t, err)
			if tt.wantValueNotNil {
				assert.NotNil(t, p.Value())
			}
			if tt.wantValue != nil {
				assert.Equal(t, tt.wantValue, p.Value())
			}
			if tt.wantIsZero {
				assert.True(t, p.IsZero())
			}
			if tt.wantString != "" {
				assert.Equal(t, tt.wantString, p.String())
			}
		})
	}
}
137  internal/core/spec/types/schedule.go  Normal file
@@ -0,0 +1,137 @@
package types

import (
	"fmt"

	"github.com/goccy/go-yaml"
)

// ScheduleValue represents a schedule configuration that can be specified as:
//   - A single cron expression string
//   - An array of cron expressions
//   - A map with start/stop/restart keys
//
// YAML examples:
//
//	schedule: "0 * * * *"
//	schedule: ["0 * * * *", "30 * * * *"]
//	schedule:
//	  start: "0 8 * * *"
//	  stop: "0 18 * * *"
//	  restart: "0 12 * * *"
type ScheduleValue struct {
	raw      any      // Original value for error reporting
	isSet    bool     // Whether the field was set in YAML
	starts   []string // Start schedules (or simple schedule expressions)
	stops    []string // Stop schedules
	restarts []string // Restart schedules
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *ScheduleValue) UnmarshalYAML(data []byte) error {
	s.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("schedule unmarshal error: %w", err)
	}
	s.raw = raw

	switch v := raw.(type) {
	case string:
		// Single cron expression
		if v != "" {
			s.starts = []string{v}
		}
		return nil

	case []any:
		// Array of cron expressions
		for i, item := range v {
			str, ok := item.(string)
			if !ok {
				return fmt.Errorf("schedule[%d]: expected string, got %T", i, item)
			}
			s.starts = append(s.starts, str)
		}
		return nil

	case []string:
		// Array of strings (from Go types)
		s.starts = v
		return nil

	case map[string]any:
		// Map with start/stop/restart keys
		return s.parseScheduleMap(v)

	case nil:
		// nil is valid, just means not set
		s.isSet = false
		return nil

	default:
		return fmt.Errorf("schedule must be string, array, or map, got %T", v)
	}
}

func (s *ScheduleValue) parseScheduleMap(m map[string]any) error {
	for key, v := range m {
		values, err := parseScheduleEntry(v)
		if err != nil {
			return fmt.Errorf("schedule.%s: %w", key, err)
		}

		switch key {
		case "start":
			s.starts = values
		case "stop":
			s.stops = values
		case "restart":
			s.restarts = values
		default:
			return fmt.Errorf("schedule: unknown key %q (expected start, stop, or restart)", key)
		}
	}
	return nil
}

func parseScheduleEntry(v any) ([]string, error) {
	switch val := v.(type) {
	case string:
		return []string{val}, nil
	case []any:
		var result []string
		for i, item := range val {
			str, ok := item.(string)
			if !ok {
				return nil, fmt.Errorf("[%d]: expected string, got %T", i, item)
			}
			result = append(result, str)
		}
		return result, nil
	default:
		return nil, fmt.Errorf("expected string or array, got %T", v)
	}
}

// IsZero returns true if the schedule was not set in YAML.
func (s ScheduleValue) IsZero() bool { return !s.isSet }

// Value returns the original raw value for error reporting.
func (s ScheduleValue) Value() any { return s.raw }

// Starts returns the start/simple schedules.
func (s ScheduleValue) Starts() []string { return s.starts }

// Stops returns the stop schedules.
func (s ScheduleValue) Stops() []string { return s.stops }

// Restarts returns the restart schedules.
func (s ScheduleValue) Restarts() []string { return s.restarts }

// HasStopSchedule returns true if stop schedules are configured.
func (s ScheduleValue) HasStopSchedule() bool { return len(s.stops) > 0 }

// HasRestartSchedule returns true if restart schedules are configured.
func (s ScheduleValue) HasRestartSchedule() bool { return len(s.restarts) > 0 }
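A usage sketch (hypothetical example test, not part of the diff): decoding the map form puts the start expression into Starts() and makes the stop key visible through HasStopSchedule().

package types_test

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

// Example_scheduleValue is a hypothetical example of the start/stop map form.
func Example_scheduleValue() {
	var cfg struct {
		Schedule types.ScheduleValue `yaml:"schedule"`
	}
	doc := "schedule:\n  start: \"0 8 * * 1-5\"\n  stop: \"0 18 * * 1-5\"\n"
	if err := yaml.Unmarshal([]byte(doc), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.Schedule.Starts(), cfg.Schedule.HasStopSchedule())
	// Output: [0 8 * * 1-5] true
}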
305  internal/core/spec/types/schedule_test.go  Normal file
@@ -0,0 +1,305 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestScheduleValue_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		input           string
		wantErr         bool
		errContains     string
		wantStarts      []string
		wantStops       []string
		wantRestarts    []string
		wantHasStop     bool
		wantHasRestart  bool
		checkHasStop    bool
		checkHasRestart bool
	}{
		{
			name:            "SingleCronExpression",
			input:           `"0 * * * *"`,
			wantStarts:      []string{"0 * * * *"},
			checkHasStop:    true,
			wantHasStop:     false,
			checkHasRestart: true,
			wantHasRestart:  false,
		},
		{
			name:       "ArrayOfCronExpressions",
			input:      `["0 * * * *", "30 * * * *"]`,
			wantStarts: []string{"0 * * * *", "30 * * * *"},
		},
		{
			name: "MultilineArray",
			input: `
- "0 8 * * *"
- "0 12 * * *"
- "0 18 * * *"
`,
			wantStarts: []string{"0 8 * * *", "0 12 * * *", "0 18 * * *"},
		},
		{
			name:         "MapWithStartOnly",
			input:        `start: "0 8 * * *"`,
			wantStarts:   []string{"0 8 * * *"},
			checkHasStop: true,
			wantHasStop:  false,
		},
		{
			name: "MapWithStartAndStop",
			input: `
start: "0 8 * * *"
stop: "0 18 * * *"
`,
			wantStarts:   []string{"0 8 * * *"},
			wantStops:    []string{"0 18 * * *"},
			checkHasStop: true,
			wantHasStop:  true,
		},
		{
			name: "MapWithAllKeys",
			input: `
start: "0 8 * * *"
stop: "0 18 * * *"
restart: "0 0 * * *"
`,
			wantStarts:      []string{"0 8 * * *"},
			wantStops:       []string{"0 18 * * *"},
			wantRestarts:    []string{"0 0 * * *"},
			checkHasRestart: true,
			wantHasRestart:  true,
		},
		{
			name: "MapWithArrayValues",
			input: `
start:
  - "0 8 * * *"
  - "0 12 * * *"
stop: "0 18 * * *"
`,
			wantStarts: []string{"0 8 * * *", "0 12 * * *"},
			wantStops:  []string{"0 18 * * *"},
		},
		{
			name:        "InvalidMapKey",
			input:       `invalid: "0 * * * *"`,
			wantErr:     true,
			errContains: "unknown key",
		},
		{
			name: "InvalidArrayElementType",
			input: `
start:
  - 123
  - 456
`,
			wantErr:     true,
			errContains: "expected string",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.ScheduleValue
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.wantStarts != nil {
				assert.Equal(t, tt.wantStarts, s.Starts())
			}
			if tt.wantStops != nil {
				assert.Equal(t, tt.wantStops, s.Stops())
			}
			if tt.wantRestarts != nil {
				assert.Equal(t, tt.wantRestarts, s.Restarts())
			}
			if tt.checkHasStop {
				assert.Equal(t, tt.wantHasStop, s.HasStopSchedule())
			}
			if tt.checkHasRestart {
				assert.Equal(t, tt.wantHasRestart, s.HasRestartSchedule())
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var s types.ScheduleValue
		assert.True(t, s.IsZero())
		assert.Nil(t, s.Starts())
	})
}

func TestScheduleValue_InStruct(t *testing.T) {
	t.Parallel()

	type DAGConfig struct {
		Name     string              `yaml:"name"`
		Schedule types.ScheduleValue `yaml:"schedule"`
	}

	tests := []struct {
		name       string
		input      string
		wantName   string
		wantStarts []string
		wantStops  []string
		wantIsZero bool
	}{
		{
			name: "SimpleSchedule",
			input: `
name: my-dag
schedule: "0 * * * *"
`,
			wantName:   "my-dag",
			wantStarts: []string{"0 * * * *"},
		},
		{
			name: "ComplexSchedule",
			input: `
name: my-dag
schedule:
  start: "0 8 * * 1-5"
  stop: "0 18 * * 1-5"
`,
			wantStarts: []string{"0 8 * * 1-5"},
			wantStops:  []string{"0 18 * * 1-5"},
		},
		{
			name:       "NoSchedule",
			input:      "name: my-dag",
			wantIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg DAGConfig
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantName != "" {
				assert.Equal(t, tt.wantName, cfg.Name)
			}
			if tt.wantStarts != nil {
				assert.Equal(t, tt.wantStarts, cfg.Schedule.Starts())
			}
			if tt.wantStops != nil {
				assert.Equal(t, tt.wantStops, cfg.Schedule.Stops())
			}
			if tt.wantIsZero {
				assert.True(t, cfg.Schedule.IsZero())
			}
		})
	}
}

func TestScheduleValue_AdditionalCoverage(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		input       string
		wantErr     bool
		errContains string
		wantValue   any
		checkIsZero bool
		wantStarts  []string
	}{
		{
			name:      "ValueReturnsRawString",
			input:     `"0 * * * *"`,
			wantValue: "0 * * * *",
		},
		{
			name:        "NullValue",
			input:       "null",
			checkIsZero: true,
		},
		{
			name:       "EmptyString",
			input:      `""`,
			wantStarts: nil,
		},
		{
			name:        "InvalidTypeNumber",
			input:       "123",
			wantErr:     true,
			errContains: "must be string, array, or map",
		},
		{
			name:        "InvalidScheduleEntryTypeInMap",
			input:       "start: 123",
			wantErr:     true,
			errContains: "expected string or array",
		},
		{
			name:        "InvalidTypeInStartArray",
			input:       `["0 * * * *", 123]`,
			wantErr:     true,
			errContains: "expected string",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.ScheduleValue
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.wantValue != nil {
				assert.Equal(t, tt.wantValue, s.Value())
			}
			if tt.checkIsZero {
				assert.True(t, s.IsZero())
			}
			if tt.wantStarts != nil {
				assert.Equal(t, tt.wantStarts, s.Starts())
			}
		})
	}

	t.Run("ValueReturnsRawArray", func(t *testing.T) {
		t.Parallel()
		var s types.ScheduleValue
		err := yaml.Unmarshal([]byte(`["0 * * * *"]`), &s)
		require.NoError(t, err)
		val, ok := s.Value().([]any)
		require.True(t, ok)
		assert.Len(t, val, 1)
	})

	t.Run("EmptyStringNotZero", func(t *testing.T) {
		t.Parallel()
		var s types.ScheduleValue
		err := yaml.Unmarshal([]byte(`""`), &s)
		require.NoError(t, err)
		assert.False(t, s.IsZero())
		assert.Nil(t, s.Starts())
	})
}
99  internal/core/spec/types/shell.go  Normal file
@@ -0,0 +1,99 @@
package types

import (
	"fmt"
	"strings"

	"github.com/goccy/go-yaml"
)

// ShellValue represents a shell configuration that can be specified as either
// a string (e.g., "bash -e") or an array (e.g., ["bash", "-e"]).
//
// YAML examples:
//
//	shell: "bash -e"
//	shell: bash
//	shell: ["bash", "-e"]
//	shell:
//	  - bash
//	  - -e
type ShellValue struct {
	raw       any      // Original value for error reporting
	isSet     bool     // Whether the field was set in YAML
	command   string   // The shell command (first element)
	arguments []string // Additional arguments
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *ShellValue) UnmarshalYAML(data []byte) error {
	s.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("shell unmarshal error: %w", err)
	}
	s.raw = raw

	switch v := raw.(type) {
	case string:
		// String value: "bash -e" or "bash"
		s.command = strings.TrimSpace(v)
		return nil

	case []any:
		// Array value: ["bash", "-e"]
		if len(v) > 0 {
			if cmd, ok := v[0].(string); ok {
				s.command = cmd
			} else {
				s.command = fmt.Sprintf("%v", v[0])
			}
			for i := 1; i < len(v); i++ {
				if arg, ok := v[i].(string); ok {
					s.arguments = append(s.arguments, arg)
				} else {
					s.arguments = append(s.arguments, fmt.Sprintf("%v", v[i]))
				}
			}
		}
		return nil

	case []string:
		// Array of strings (from Go types)
		if len(v) > 0 {
			s.command = v[0]
			s.arguments = v[1:]
		}
		return nil

	case nil:
		s.isSet = false
		return nil

	default:
		return fmt.Errorf("shell must be string or array, got %T", v)
	}
}

// IsZero returns true if the shell value was not set in YAML.
func (s ShellValue) IsZero() bool { return !s.isSet }

// Value returns the original raw value for error reporting.
func (s ShellValue) Value() any { return s.raw }

// Command returns the shell command (first element or full string).
func (s ShellValue) Command() string { return s.command }

// Arguments returns the additional shell arguments (only set for array form).
func (s ShellValue) Arguments() []string { return s.arguments }

// IsArray returns true if the value was specified as an array.
func (s ShellValue) IsArray() bool {
	switch s.raw.(type) {
	case []any, []string:
		return true
	default:
		return false
	}
}
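A usage sketch (hypothetical example test, not part of the diff): the array form splits into Command() plus Arguments(), while a plain string stays a single command value.

package types_test

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

// Example_shellValue is a hypothetical example of the array form.
func Example_shellValue() {
	var cfg struct {
		Shell types.ShellValue `yaml:"shell"`
	}
	if err := yaml.Unmarshal([]byte(`shell: ["bash", "-e"]`), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.Shell.Command(), cfg.Shell.Arguments(), cfg.Shell.IsArray())
	// Output: bash [-e] true
}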
269  internal/core/spec/types/shell_test.go  Normal file
@@ -0,0 +1,269 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestShellValue_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		input        string
		wantErr      bool
		errContains  string
		wantCommand  string
		wantArgs     []string
		wantIsArray  bool
		checkNotZero bool
		checkIsZero  bool
	}{
		{
			name:         "StringWithoutArgs",
			input:        "bash",
			wantCommand:  "bash",
			wantArgs:     nil,
			wantIsArray:  false,
			checkNotZero: true,
		},
		{
			name:        "StringWithArgsInline",
			input:       `"bash -e"`,
			wantCommand: "bash -e",
			wantArgs:    nil,
		},
		{
			name:        "ArrayFormInline",
			input:       `["bash", "-e", "-x"]`,
			wantCommand: "bash",
			wantArgs:    []string{"-e", "-x"},
			wantIsArray: true,
		},
		{
			name:        "MultilineArrayForm",
			input:       "- bash\n- -e\n- -x",
			wantCommand: "bash",
			wantArgs:    []string{"-e", "-x"},
		},
		{
			name:         "EmptyString",
			input:        `""`,
			wantCommand:  "",
			checkNotZero: true,
		},
		{
			name:        "EmptyArray",
			input:       "[]",
			wantCommand: "",
			wantArgs:    nil,
		},
		{
			name:        "InvalidTypeMap",
			input:       "{key: value}",
			wantErr:     true,
			errContains: "must be string or array",
		},
		{
			name:        "ShellWithEnvVariableSyntax",
			input:       `"${SHELL}"`,
			wantCommand: "${SHELL}",
		},
		{
			name:        "NixShellExample",
			input:       `["nix-shell", "-p", "python3"]`,
			wantCommand: "nix-shell",
			wantArgs:    []string{"-p", "python3"},
		},
		{
			name:        "NullValue",
			input:       "null",
			checkIsZero: true,
		},
		{
			name:        "ArrayWithNonStringItems",
			input:       "[123, true]",
			wantCommand: "123",
			wantArgs:    []string{"true"},
		},
		{
			name:        "SingleElementArray",
			input:       `["bash"]`,
			wantCommand: "bash",
			wantArgs:    nil,
			wantIsArray: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.ShellValue
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.checkIsZero {
				assert.True(t, s.IsZero())
				return
			}
			if tt.checkNotZero {
				assert.False(t, s.IsZero())
			}
			assert.Equal(t, tt.wantCommand, s.Command())
			if tt.wantArgs != nil {
				assert.Equal(t, tt.wantArgs, s.Arguments())
			} else {
				assert.Empty(t, s.Arguments())
			}
			if tt.wantIsArray {
				assert.True(t, s.IsArray())
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var s types.ShellValue
		assert.True(t, s.IsZero())
	})
}

func TestShellValue_InStruct(t *testing.T) {
	t.Parallel()

	type Config struct {
		Shell types.ShellValue `yaml:"shell"`
		Name  string           `yaml:"name"`
	}

	tests := []struct {
		name         string
		input        string
		wantName     string
		wantCommand  string
		wantIsZero   bool
		checkNotZero bool
	}{
		{
			name: "ShellSet",
			input: `
name: test
shell: bash
`,
			wantName:     "test",
			wantCommand:  "bash",
			checkNotZero: true,
		},
		{
			name:       "ShellNotSet",
			input:      "name: test",
			wantIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg Config
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantName != "" {
				assert.Equal(t, tt.wantName, cfg.Name)
			}
			if tt.wantCommand != "" {
				assert.Equal(t, tt.wantCommand, cfg.Shell.Command())
			}
			if tt.wantIsZero {
				assert.True(t, cfg.Shell.IsZero())
			}
			if tt.checkNotZero {
				assert.False(t, cfg.Shell.IsZero())
			}
		})
	}
}

func TestShellValue_AdditionalCoverage(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		input        string
		wantErr      bool
		errContains  string
		wantValue    any
		checkIsArray bool
		wantIsArray  bool
		checkIsZero  bool
	}{
		{
			name:      "ValueReturnsRawString",
			input:     "bash",
			wantValue: "bash",
		},
		{
			name:        "InvalidTypeNumber",
			input:       "123",
			wantErr:     true,
			errContains: "must be string or array",
		},
		{
			name:         "IsArrayReturnsFalseForString",
			input:        `"bash -e"`,
			checkIsArray: true,
			wantIsArray:  false,
		},
		{
			name:         "IsArrayReturnsFalseForNil",
			input:        "null",
			checkIsArray: true,
			wantIsArray:  false,
			checkIsZero:  true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.ShellValue
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.wantValue != nil {
				assert.Equal(t, tt.wantValue, s.Value())
			}
			if tt.checkIsArray {
				assert.Equal(t, tt.wantIsArray, s.IsArray())
			}
			if tt.checkIsZero {
				assert.True(t, s.IsZero())
			}
		})
	}

	t.Run("ValueReturnsRawArray", func(t *testing.T) {
		t.Parallel()
		var s types.ShellValue
		err := yaml.Unmarshal([]byte(`["bash", "-e"]`), &s)
		require.NoError(t, err)
		val, ok := s.Value().([]any)
		require.True(t, ok)
		assert.Len(t, val, 2)
	})
}
90  internal/core/spec/types/stringarray.go  Normal file
@@ -0,0 +1,90 @@
package types

import (
	"fmt"

	"github.com/goccy/go-yaml"
)

// StringOrArray represents a value that can be specified as either a single
// string or an array of strings.
//
// YAML examples:
//
//	depends: "step1"
//	depends: ["step1", "step2"]
//	dotenv: ".env"
//	dotenv: [".env", ".env.local"]
type StringOrArray struct {
	raw    any      // Original value for error reporting
	isSet  bool     // Whether the field was set in YAML
	values []string // Parsed values
}

// UnmarshalYAML implements BytesUnmarshaler for goccy/go-yaml.
func (s *StringOrArray) UnmarshalYAML(data []byte) error {
	s.isSet = true

	var raw any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("unmarshal error: %w", err)
	}
	s.raw = raw

	switch v := raw.(type) {
	case string:
		// Single string value - preserve empty strings for validation layer to handle
		s.values = []string{v}
		return nil

	case []any:
		// Array of values - convert non-strings to strings for compatibility
		for _, item := range v {
			if str, ok := item.(string); ok {
				s.values = append(s.values, str)
			} else {
				// Stringify non-string items (e.g., numeric tags)
				s.values = append(s.values, fmt.Sprintf("%v", item))
			}
		}
		return nil

	case []string:
		// Array of strings (from Go types)
		s.values = v
		return nil

	case nil:
		s.isSet = false
		return nil

	default:
		return fmt.Errorf("must be string or array, got %T", v)
	}
}

// IsZero returns true if the value was not set in YAML.
func (s StringOrArray) IsZero() bool { return !s.isSet }

// Value returns the original raw value for error reporting.
func (s StringOrArray) Value() any { return s.raw }

// Values returns the parsed string values.
func (s StringOrArray) Values() []string { return s.values }

// IsEmpty returns true if set but contains no values (empty array).
func (s StringOrArray) IsEmpty() bool { return s.isSet && len(s.values) == 0 }

// MailToValue is an alias for StringOrArray used for email recipients.
// YAML examples:
//
//	to: user@example.com
//	to: ["user1@example.com", "user2@example.com"]
type MailToValue = StringOrArray

// TagsValue is an alias for StringOrArray used for tags.
// YAML examples:
//
//	tags: production
//	tags: ["production", "critical"]
type TagsValue = StringOrArray
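A usage sketch (hypothetical example test, not part of the diff): scalar and sequence forms of a depends field both come back as a slice from Values().

package types_test

import (
	"fmt"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
)

// Example_stringOrArray is a hypothetical example covering both YAML forms.
func Example_stringOrArray() {
	for _, doc := range []string{`depends: step1`, `depends: [step1, step2]`} {
		var step struct {
			Depends types.StringOrArray `yaml:"depends"`
		}
		if err := yaml.Unmarshal([]byte(doc), &step); err != nil {
			panic(err)
		}
		fmt.Println(step.Depends.Values())
	}
	// Output:
	// [step1]
	// [step1 step2]
}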
309  internal/core/spec/types/stringarray_test.go  Normal file
@@ -0,0 +1,309 @@
package types_test

import (
	"testing"

	"github.com/dagu-org/dagu/internal/core/spec/types"
	"github.com/goccy/go-yaml"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestStringOrArray_UnmarshalYAML(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		input        string
		wantErr      bool
		errContains  string
		wantValues   []string
		checkIsEmpty bool
		wantIsEmpty  bool
		checkNotZero bool
	}{
		{
			name:         "SingleString",
			input:        "step1",
			wantValues:   []string{"step1"},
			checkNotZero: true,
		},
		{
			name:       "ArrayOfStringsInline",
			input:      `["step1", "step2", "step3"]`,
			wantValues: []string{"step1", "step2", "step3"},
		},
		{
			name:       "MultilineArray",
			input:      "- step1\n- step2",
			wantValues: []string{"step1", "step2"},
		},
		{
			name:         "EmptyString",
			input:        `""`,
			wantValues:   []string{""},
			checkIsEmpty: true,
			wantIsEmpty:  false,
			checkNotZero: true,
		},
		{
			name:         "EmptyArray",
			input:        "[]",
			wantValues:   nil,
			checkIsEmpty: true,
			wantIsEmpty:  true,
		},
		{
			name:        "InvalidTypeMap",
			input:       "{key: value}",
			wantErr:     true,
			errContains: "must be string or array",
		},
		{
			name:       "QuotedStringWithSpaces",
			input:      `"step with spaces"`,
			wantValues: []string{"step with spaces"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.StringOrArray
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.wantValues == nil {
				assert.Empty(t, s.Values())
			} else {
				assert.Equal(t, tt.wantValues, s.Values())
			}
			if tt.checkIsEmpty {
				assert.Equal(t, tt.wantIsEmpty, s.IsEmpty())
			}
			if tt.checkNotZero {
				assert.False(t, s.IsZero())
			}
		})
	}

	t.Run("ZeroValue", func(t *testing.T) {
		t.Parallel()
		var s types.StringOrArray
		assert.True(t, s.IsZero())
		assert.Nil(t, s.Values())
	})
}

func TestStringOrArray_InStruct(t *testing.T) {
	t.Parallel()

	type StepConfig struct {
		Name    string              `yaml:"name"`
		Depends types.StringOrArray `yaml:"depends"`
	}

	tests := []struct {
		name         string
		input        string
		wantValues   []string
		wantIsZero   bool
		checkIsEmpty bool
		wantIsEmpty  bool
	}{
		{
			name: "DependsAsString",
			input: `
name: step2
depends: step1
`,
			wantValues: []string{"step1"},
		},
		{
			name: "DependsAsArray",
			input: `
name: step3
depends:
  - step1
  - step2
`,
			wantValues: []string{"step1", "step2"},
		},
		{
			name:       "DependsNotSet",
			input:      "name: step1",
			wantIsZero: true,
		},
		{
			name: "DependsEmptyArray",
			input: `
name: step2
depends: []
`,
			checkIsEmpty: true,
			wantIsEmpty:  true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var cfg StepConfig
			err := yaml.Unmarshal([]byte(tt.input), &cfg)
			require.NoError(t, err)
			if tt.wantValues != nil {
				assert.Equal(t, tt.wantValues, cfg.Depends.Values())
			}
			if tt.wantIsZero {
				assert.True(t, cfg.Depends.IsZero())
			}
			if tt.checkIsEmpty {
				assert.False(t, cfg.Depends.IsZero())
				assert.Equal(t, tt.wantIsEmpty, cfg.Depends.IsEmpty())
			}
		})
	}
}

func TestMailToValue(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		input      string
		wantValues []string
	}{
		{
			name:       "SingleEmail",
			input:      "user@example.com",
			wantValues: []string{"user@example.com"},
		},
		{
			name:       "MultipleEmails",
			input:      `["user1@example.com", "user2@example.com"]`,
			wantValues: []string{"user1@example.com", "user2@example.com"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var m types.MailToValue
			err := yaml.Unmarshal([]byte(tt.input), &m)
			require.NoError(t, err)
			assert.Equal(t, tt.wantValues, m.Values())
		})
	}
}

func TestTagsValue(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		input      string
		wantValues []string
	}{
		{
			name:       "SingleTag",
			input:      "production",
			wantValues: []string{"production"},
		},
		{
			name:       "MultipleTags",
			input:      `["production", "critical", "monitored"]`,
			wantValues: []string{"production", "critical", "monitored"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var tags types.TagsValue
			err := yaml.Unmarshal([]byte(tt.input), &tags)
			require.NoError(t, err)
			assert.Equal(t, tt.wantValues, tags.Values())
		})
	}
}

func TestStringOrArray_AdditionalCoverage(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		input       string
		wantErr     bool
		errContains string
		wantValues  []string
		checkIsZero bool
	}{
		{
			name:       "ArrayWithNumericValues",
			input:      "[1, 2, 3]",
			wantValues: []string{"1", "2", "3"},
		},
		{
			name:       "ArrayWithMixedTypes",
			input:      `["step1", 123, true]`,
			wantValues: []string{"step1", "123", "true"},
		},
		{
			name:        "InvalidTypeNumber",
			input:       "123",
			wantErr:     true,
			errContains: "must be string or array",
		},
		{
			name:        "NullValue",
			input:       "null",
			checkIsZero: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var s types.StringOrArray
			err := yaml.Unmarshal([]byte(tt.input), &s)
			if tt.wantErr {
				require.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				return
			}
			require.NoError(t, err)
			if tt.wantValues != nil {
				assert.Equal(t, tt.wantValues, s.Values())
			}
			if tt.checkIsZero {
				assert.True(t, s.IsZero())
			}
		})
	}

	t.Run("ValueReturnsRawString", func(t *testing.T) {
		t.Parallel()
		var s types.StringOrArray
		err := yaml.Unmarshal([]byte("step1"), &s)
		require.NoError(t, err)
		assert.Equal(t, "step1", s.Value())
	})

	t.Run("ValueReturnsRawArray", func(t *testing.T) {
		t.Parallel()
		var s types.StringOrArray
		err := yaml.Unmarshal([]byte(`["step1", "step2"]`), &s)
		require.NoError(t, err)
		val, ok := s.Value().([]any)
		require.True(t, ok)
		assert.Len(t, val, 2)
	})
}
@@ -7,6 +7,7 @@ import (
 
 	"github.com/dagu-org/dagu/internal/common/cmdutil"
 	"github.com/dagu-org/dagu/internal/core"
+	"github.com/dagu-org/dagu/internal/core/spec/types"
 )
 
 // loadVariables loads the environment variables from the map.
@@ -89,6 +90,43 @@ func loadVariables(ctx BuildContext, strVariables any) (
 	return vars, nil
 }
 
+// loadVariablesFromEnvValue loads environment variables from a types.EnvValue.
+// This function converts the typed EnvValue entries to the expected format
+// and processes them using the same logic as loadVariables.
+func loadVariablesFromEnvValue(ctx BuildContext, env types.EnvValue) (
+	map[string]string, error,
+) {
+	if env.IsZero() {
+		return nil, nil
+	}
+
+	vars := map[string]string{}
+	for _, entry := range env.Entries() {
+		value := entry.Value
+
+		if !ctx.opts.Has(BuildFlagNoEval) {
+			// Evaluate the value of the environment variable.
+			// This also executes command substitution.
+			// Pass accumulated vars so ${VAR} can reference previously defined vars
+			var err error
+
+			value, err = cmdutil.EvalString(ctx.ctx, value, cmdutil.WithVariables(vars))
+			if err != nil {
+				return nil, core.NewValidationError("env", entry.Value, fmt.Errorf("%w: %s", ErrInvalidEnvValue, entry.Value))
+			}
+
+			// Set the environment variable.
+			if err := os.Setenv(entry.Key, value); err != nil {
+				return nil, core.NewValidationError("env", entry.Key, fmt.Errorf("%w: %s", err, entry.Key))
+			}
+		}
+
+		vars[entry.Key] = value
+	}
+
+	return vars, nil
+}
+
 // pair represents a key-value pair.
 type pair struct {
 	key string
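For orientation, a minimal in-package sketch (hypothetical, since loadVariablesFromEnvValue is unexported; assumes the context, types, and yaml imports already present in this package) of how the accumulated vars map lets a later entry reference an earlier one:

// Hypothetical in-package sketch; the helper name exampleEnvEval is not in the diff.
func exampleEnvEval(ctx context.Context) (map[string]string, error) {
	var env types.EnvValue
	// The array form keeps order, so PATH_VAR can see BASE via the vars map
	// threaded through cmdutil.EvalString.
	doc := "- BASE: /opt\n- PATH_VAR: ${BASE}/bin"
	if err := yaml.Unmarshal([]byte(doc), &env); err != nil {
		return nil, err
	}
	// With an empty BuildOpts, evaluation is on and PATH_VAR resolves to /opt/bin.
	return loadVariablesFromEnvValue(BuildContext{ctx: ctx, opts: BuildOpts{}}, env)
}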
420  internal/core/spec/variables_test.go  Normal file
@@ -0,0 +1,420 @@
|
||||
package spec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/dagu-org/dagu/internal/core/spec/types"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseKeyValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]any
|
||||
expected []pair
|
||||
}{
|
||||
{
|
||||
name: "EmptyMap",
|
||||
input: map[string]any{},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "SingleStringValue",
|
||||
input: map[string]any{"FOO": "bar"},
|
||||
expected: []pair{{key: "FOO", val: "bar"}},
|
||||
},
|
||||
{
|
||||
name: "IntegerValue",
|
||||
input: map[string]any{"COUNT": 42},
|
||||
expected: []pair{{key: "COUNT", val: "42"}},
|
||||
},
|
||||
{
|
||||
name: "BooleanValue",
|
||||
input: map[string]any{"DEBUG": true},
|
||||
expected: []pair{{key: "DEBUG", val: "true"}},
|
||||
},
|
||||
{
|
||||
name: "FloatValue",
|
||||
input: map[string]any{"RATIO": 3.14},
|
||||
expected: []pair{{key: "RATIO", val: "3.14"}},
|
||||
},
|
||||
{
|
||||
name: "MultipleValues",
|
||||
input: map[string]any{"A": "1", "B": "2"},
|
||||
expected: []pair{
|
||||
{key: "A", val: "1"},
|
||||
{key: "B", val: "2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var pairs []pair
|
||||
err := parseKeyValue(tt.input, &pairs)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.expected == nil {
|
||||
assert.Empty(t, pairs)
|
||||
return
|
||||
}
|
||||
|
||||
// Since map iteration order is not guaranteed, check by content
|
||||
assert.Len(t, pairs, len(tt.expected))
|
||||
for _, exp := range tt.expected {
|
||||
found := false
|
||||
for _, p := range pairs {
|
||||
if p.key == exp.key && p.val == exp.val {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected pair %v not found", exp)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadVariables(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("MapInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string]any
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
name: "SingleVariable",
|
||||
input: map[string]any{"FOO": "bar"},
|
||||
expected: map[string]string{"FOO": "bar"},
|
||||
},
|
||||
{
|
||||
name: "MultipleVariables",
|
||||
input: map[string]any{"A": "1", "B": "2"},
|
||||
expected: map[string]string{"A": "1", "B": "2"},
|
||||
},
|
||||
{
|
||||
name: "IntegerValue",
|
||||
input: map[string]any{"PORT": 8080},
|
||||
expected: map[string]string{"PORT": "8080"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := BuildContext{
|
||||
ctx: context.Background(),
|
||||
opts: BuildOpts{Flags: BuildFlagNoEval},
|
||||
}
|
||||
result, err := loadVariables(ctx, tt.input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArrayInput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []any
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
name: "ArrayOfMaps",
|
||||
input: []any{
|
||||
map[string]any{"FOO": "bar"},
|
||||
map[string]any{"BAZ": "qux"},
|
||||
},
|
||||
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
|
||||
},
|
||||
{
|
||||
name: "ArrayOfStrings",
|
||||
input: []any{
|
||||
"FOO=bar",
|
||||
"BAZ=qux",
|
||||
},
|
||||
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
|
||||
},
|
||||
{
|
||||
name: "MixedArrayOfMapsAndStrings",
|
||||
input: []any{
|
||||
map[string]any{"FOO": "bar"},
|
||||
"BAZ=qux",
|
||||
},
|
||||
expected: map[string]string{"FOO": "bar", "BAZ": "qux"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := BuildContext{
|
||||
ctx: context.Background(),
|
||||
opts: BuildOpts{Flags: BuildFlagNoEval},
|
||||
}
|
||||
result, err := loadVariables(ctx, tt.input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrorCases", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input any
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "InvalidStringFormat",
|
||||
input: []any{"INVALID_NO_EQUALS"},
|
||||
errContains: "env config should be map of strings or array of key=value",
|
||||
},
|
||||
{
|
||||
name: "InvalidTypeInArray",
|
||||
input: []any{123},
|
||||
errContains: "env config should be map of strings or array of key=value",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := BuildContext{
|
||||
ctx: context.Background(),
|
||||
opts: BuildOpts{Flags: BuildFlagNoEval},
|
||||
}
|
||||
_, err := loadVariables(ctx, tt.input)
|
||||
require.Error(t, err)
|
				assert.Contains(t, err.Error(), tt.errContains)
			})
		}
	})

	t.Run("WithEvaluation", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		input := map[string]any{"GREETING": "hello"}
		result, err := loadVariables(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, "hello", result["GREETING"])
	})

	t.Run("NoEvalFlag", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		// With NoEval, command substitution should not be executed
		input := map[string]any{"CMD": "$(echo hello)"}
		result, err := loadVariables(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, "$(echo hello)", result["CMD"])
	})

	t.Run("VariableReference", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		// Test that later variables can reference earlier ones
		input := []any{
			map[string]any{"BASE": "/opt"},
			map[string]any{"PATH_VAR": "${BASE}/bin"},
		}
		result, err := loadVariables(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, "/opt", result["BASE"])
		assert.Equal(t, "/opt/bin", result["PATH_VAR"])
	})

	t.Run("EmptyValue", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		input := map[string]any{"EMPTY": ""}
		result, err := loadVariables(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, "", result["EMPTY"])
	})

	t.Run("ValueWithEqualsSign", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		input := []any{"KEY=value=with=equals"}
		result, err := loadVariables(ctx, input)
		require.NoError(t, err)
		assert.Equal(t, "value=with=equals", result["KEY"])
	})
}

// Helper to create EnvValue from YAML string
func envValueFromYAML(t *testing.T, yamlStr string) types.EnvValue {
	t.Helper()
	var env types.EnvValue
	err := yaml.Unmarshal([]byte(yamlStr), &env)
	require.NoError(t, err)
	return env
}

func TestLoadVariablesFromEnvValue(t *testing.T) {
	t.Parallel()

	t.Run("EmptyEnvValue", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		var env types.EnvValue
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Nil(t, result)
	})

	t.Run("MapFormat", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		env := envValueFromYAML(t, `
FOO: bar
BAZ: qux
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "bar", result["FOO"])
		assert.Equal(t, "qux", result["BAZ"])
	})

	t.Run("ArrayFormat", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		env := envValueFromYAML(t, `
- FOO: bar
- BAZ: qux
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "bar", result["FOO"])
		assert.Equal(t, "qux", result["BAZ"])
	})

	t.Run("WithEvaluation", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		env := envValueFromYAML(t, `
GREETING: hello
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "hello", result["GREETING"])
	})

	t.Run("NoEvalFlag", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		env := envValueFromYAML(t, `
CMD: "$(echo hello)"
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "$(echo hello)", result["CMD"])
	})

	t.Run("VariableReference", func(t *testing.T) {
		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{},
		}

		env := envValueFromYAML(t, `
- BASE: /opt
- PATH_VAR: "${BASE}/bin"
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "/opt", result["BASE"])
		assert.Equal(t, "/opt/bin", result["PATH_VAR"])
	})

	t.Run("IntegerValue", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		env := envValueFromYAML(t, `
PORT: 8080
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "8080", result["PORT"])
	})

	t.Run("BooleanValue", func(t *testing.T) {
		t.Parallel()

		ctx := BuildContext{
			ctx:  context.Background(),
			opts: BuildOpts{Flags: BuildFlagNoEval},
		}

		env := envValueFromYAML(t, `
DEBUG: true
`)
		result, err := loadVariablesFromEnvValue(ctx, env)
		require.NoError(t, err)
		assert.Equal(t, "true", result["DEBUG"])
	})
}
@@ -21,30 +21,46 @@ type Step struct {
	Shell string `json:"shell,omitempty"`
	// ShellPackages is the list of packages to install. This is used only when the shell is `nix-shell`.
	ShellPackages []string `json:"shellPackages,omitempty"`
	// SHell Args is the list of arguments for the shell program.
	// ShellArgs is the list of arguments for the shell program.
	ShellArgs []string `json:"shellArgs,omitempty"`
	// Dir is the working directory for the step.
	Dir string `json:"dir,omitempty"`
	// ExecutorConfig contains the configuration for the executor.
	ExecutorConfig ExecutorConfig `json:"executorConfig,omitzero"`
	// CmdWithArgs is the command with arguments (only display purpose).
	// CmdWithArgs is the command with arguments for display purposes.
	// Deprecated: Use Commands[0].CmdWithArgs instead. Kept for JSON backward compatibility.
	CmdWithArgs string `json:"cmdWithArgs,omitempty"`
	// CmdArgsSys is the command with arguments for the system.
	// Deprecated: Kept for JSON backward compatibility.
	CmdArgsSys string `json:"cmdArgsSys,omitempty"`
	// Command specifies only the command without arguments.
	// Deprecated: Use Commands field instead. Kept for JSON backward compatibility.
	Command string `json:"command,omitempty"`
	// ShellCmdArgs is the shell command with arguments.
	ShellCmdArgs string `json:"shellCmdArgs,omitempty"`
	// Script is the script to be executed.
	Script string `json:"script,omitempty"`
	// Args contains the arguments for the command.
	// Deprecated: Use Commands field instead. Kept for JSON backward compatibility.
	Args []string `json:"args,omitempty"`
	// Commands is the source of truth for commands to execute.
	// Each entry represents a command to be executed sequentially.
	// For single commands, this will contain exactly one entry.
	Commands []CommandEntry `json:"commands,omitempty"`
	// Stdout is the file to store the standard output.
	Stdout string `json:"stdout,omitempty"`
	// Stderr is the file to store the standard error.
	Stderr string `json:"stderr,omitempty"`
	// LogOutput specifies how stdout and stderr are handled in log files for this step.
	// Overrides the DAG-level LogOutput setting. Empty string means inherit from DAG.
	LogOutput LogOutputMode `json:"logOutput,omitempty"`
	// Output is the variable name to store the output.
	Output string `json:"output,omitempty"`
	// OutputKey is the custom key for the output in outputs.json.
	// If empty, the Output name is converted from UPPER_CASE to camelCase.
	OutputKey string `json:"outputKey,omitempty"`
	// OutputOmit excludes this output from outputs.json when true.
	OutputOmit bool `json:"outputOmit,omitempty"`
	// Depends contains the list of step names to depend on.
	Depends []string `json:"depends,omitempty"`
	// ExplicitlyNoDeps indicates the depends field was explicitly set to empty
@@ -74,6 +90,10 @@ type Step struct {
	// Timeout specifies the maximum execution time for the step.
	// If set, this timeout takes precedence over the DAG-level timeout for this step.
	Timeout time.Duration `json:"timeout,omitempty"`
	// Container specifies the container configuration for this step.
	// If set, the step runs in its own container instead of the DAG-level container.
	// This uses the same configuration format as the DAG-level container field.
	Container *Container `json:"container,omitempty"`
}

// String returns a formatted string representation of the step
@@ -103,6 +123,68 @@ type SubDAG struct {
	Params string `json:"params,omitempty"`
}

// CommandEntry represents a single command in a multi-command step.
// Each entry contains a parsed command with its arguments.
type CommandEntry struct {
	// Command is the executable name or path.
	Command string `json:"command"`
	// Args contains the arguments for the command.
	Args []string `json:"args,omitempty"`
	// CmdWithArgs is the original command string for display purposes.
	CmdWithArgs string `json:"cmdWithArgs,omitempty"`
}

// String returns a display string for the command entry.
func (c CommandEntry) String() string {
	if c.CmdWithArgs != "" {
		return c.CmdWithArgs
	}
	if c.Command == "" {
		return ""
	}
	if len(c.Args) == 0 {
		return c.Command
	}
	return c.Command + " " + strings.Join(c.Args, " ")
}

// HasMultipleCommands returns true if the step has multiple commands to execute.
func (s *Step) HasMultipleCommands() bool {
	return len(s.Commands) > 1
}

// UnmarshalJSON implements json.Unmarshaler for backward compatibility.
// It handles old JSON format where command/args fields were used instead of commands.
func (s *Step) UnmarshalJSON(data []byte) error {
	// Use type alias to avoid infinite recursion
	type Alias Step
	aux := &struct {
		*Alias
	}{
		Alias: (*Alias)(s),
	}

	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}

	// If Commands is already populated, we're done (new format)
	if len(s.Commands) > 0 {
		return nil
	}

	// Migrate legacy fields to Commands only when legacy command data exists.
	if s.Command != "" || len(s.Args) > 0 || s.CmdWithArgs != "" {
		s.Commands = []CommandEntry{{
			Command:     s.Command,
			Args:        s.Args,
			CmdWithArgs: s.CmdWithArgs,
		}}
	}

	return nil
}

// ExecutorConfig contains the configuration for the executor.
type ExecutorConfig struct {
	// Type represents one of the registered executors.
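The type-alias round-trip above is the whole backward-compatibility story: decode with the method stripped, then migrate legacy fields. A minimal, self-contained sketch of the same pattern (a trimmed stand-in type, not the real Step):

	package main

	import (
		"encoding/json"
		"fmt"
	)

	// legacyStep keeps only the fields needed to show the migration.
	type legacyStep struct {
		Command  string         `json:"command,omitempty"`
		Args     []string       `json:"args,omitempty"`
		Commands []commandEntry `json:"commands,omitempty"`
	}

	type commandEntry struct {
		Command string   `json:"command"`
		Args    []string `json:"args,omitempty"`
	}

	func (s *legacyStep) UnmarshalJSON(data []byte) error {
		type alias legacyStep // alias has no UnmarshalJSON, so this cannot recurse
		if err := json.Unmarshal(data, (*alias)(s)); err != nil {
			return err
		}
		// New format already populated Commands; nothing to migrate.
		if len(s.Commands) > 0 {
			return nil
		}
		// Old format: lift command/args into a single Commands entry.
		if s.Command != "" || len(s.Args) > 0 {
			s.Commands = []commandEntry{{Command: s.Command, Args: s.Args}}
		}
		return nil
	}

	func main() {
		var s legacyStep
		_ = json.Unmarshal([]byte(`{"command":"echo","args":["hi"]}`), &s)
		fmt.Printf("%+v\n", s.Commands) // [{Command:echo Args:[hi]}]
	}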
@@ -276,3 +276,43 @@ func TestRepeatPolicy_MarshalUnmarshal(t *testing.T) {
		assert.Equal(t, rp.RepeatMode, rp2.RepeatMode)
	})
}

func TestCommandEntry_String(t *testing.T) {
	tests := []struct {
		name     string
		entry    CommandEntry
		expected string
	}{
		{
			name:     "Empty",
			entry:    CommandEntry{},
			expected: "",
		},
		{
			name:     "CommandOnly",
			entry:    CommandEntry{Command: "echo"},
			expected: "echo",
		},
		{
			name:     "CommandWithArgs",
			entry:    CommandEntry{Command: "echo", Args: []string{"hello", "world"}},
			expected: "echo hello world",
		},
		{
			name:     "CmdWithArgsTakesPriority",
			entry:    CommandEntry{Command: "echo", Args: []string{"hello"}, CmdWithArgs: "echo 'hello world'"},
			expected: "echo 'hello world'",
		},
		{
			name:     "CmdWithArgsOnly",
			entry:    CommandEntry{CmdWithArgs: "ls -la"},
			expected: "ls -la",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expected, tt.entry.String())
		})
	}
}
@@ -6,6 +6,28 @@ import (
	"strings"
)

// DAGNameMaxLen defines the maximum allowed length for a DAG name.
const DAGNameMaxLen = 40

// dagNameRegex matches valid DAG names: alphanumeric, underscore, dash, dot.
var dagNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)

// ValidateDAGName validates a DAG name according to shared rules.
// - Empty name is allowed (caller may provide one via context or filename).
// - Non-empty name must satisfy length and allowed character constraints.
func ValidateDAGName(name string) error {
	if name == "" {
		return nil
	}
	if len(name) > DAGNameMaxLen {
		return ErrNameTooLong
	}
	if !dagNameRegex.MatchString(name) {
		return ErrNameInvalidChars
	}
	return nil
}

// StepValidator is a function type for validating step configurations.
type StepValidator func(step Step) error
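A quick illustration of how the rules above behave for a hypothetical caller (only names shown in this hunk are used):

	// Illustrative only; not part of this commit.
	_ = ValidateDAGName("etl-daily.v2")              // nil: letters, digits, dot, dash, underscore are allowed
	_ = ValidateDAGName("")                          // nil: empty names may be filled in via context or filename
	_ = ValidateDAGName(strings.Repeat("x", 41))     // ErrNameTooLong: over the 40-character limit
	_ = ValidateDAGName("etl/daily")                 // ErrNameInvalidChars: "/" is outside [a-zA-Z0-9_.-]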
internal/core/validator_test.go (new file, 774 lines)
@@ -0,0 +1,774 @@
package core

import (
	"errors"
	"fmt"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestIsValidStepID(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		id       string
		expected bool
	}{
		// Valid cases - starts with letter, followed by alphanumeric/dash/underscore
		{name: "single letter", id: "a", expected: true},
		{name: "simple word", id: "step", expected: true},
		{name: "word with number", id: "step1", expected: true},
		{name: "word with dash", id: "my-step", expected: true},
		{name: "word with underscore", id: "my_step", expected: true},
		{name: "mixed case", id: "MyStep", expected: true},
		{name: "uppercase", id: "STEP", expected: true},
		{name: "complex valid id", id: "Step123-test_id", expected: true},
		{name: "letters and numbers", id: "step123abc", expected: true},
		{name: "uppercase with numbers", id: "STEP123", expected: true},

		// Invalid cases
		{name: "starts with number", id: "1step", expected: false},
		{name: "starts with dash", id: "-step", expected: false},
		{name: "starts with underscore", id: "_step", expected: false},
		{name: "contains space", id: "step name", expected: false},
		{name: "contains exclamation", id: "step!", expected: false},
		{name: "contains at sign", id: "step@test", expected: false},
		{name: "contains dot", id: "step.name", expected: false},
		{name: "empty string", id: "", expected: false},
		{name: "only numbers", id: "123", expected: false},
		{name: "contains slash", id: "step/name", expected: false},
		{name: "contains colon", id: "step:name", expected: false},
		{name: "contains equals", id: "step=value", expected: false},
		{name: "unicode characters", id: "step日本語", expected: false},
		{name: "emoji", id: "step🚀", expected: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := isValidStepID(tt.id)
			assert.Equal(t, tt.expected, result,
				"isValidStepID(%q) = %v, want %v", tt.id, result, tt.expected)
		})
	}
}

func TestIsReservedWord(t *testing.T) {
	t.Parallel()

	// All reserved words (case insensitive)
	reservedWords := []string{"env", "params", "args", "stdout", "stderr", "output", "outputs"}

	t.Run("reserved words are detected", func(t *testing.T) {
		t.Parallel()
		for _, word := range reservedWords {
			assert.True(t, isReservedWord(word),
				"isReservedWord(%q) should return true", word)
		}
	})

	t.Run("reserved words uppercase are detected", func(t *testing.T) {
		t.Parallel()
		for _, word := range reservedWords {
			upper := strings.ToUpper(word)
			assert.True(t, isReservedWord(upper),
				"isReservedWord(%q) should return true", upper)
		}
	})

	t.Run("reserved words mixed case are detected", func(t *testing.T) {
		t.Parallel()
		mixedCases := []string{"Env", "PARAMS", "Args", "StdOut", "StdErr", "Output", "Outputs"}
		for _, word := range mixedCases {
			assert.True(t, isReservedWord(word),
				"isReservedWord(%q) should return true", word)
		}
	})

	t.Run("non-reserved words are not detected", func(t *testing.T) {
		t.Parallel()
		nonReserved := []string{
			"environment",
			"parameter",
			"arguments",
			"step",
			"run",
			"execute",
			"command",
			"envs",
			"param",
			"arg",
			"out",
			"err",
			"myenv",
			"test-stdout",
		}
		for _, word := range nonReserved {
			assert.False(t, isReservedWord(word),
				"isReservedWord(%q) should return false", word)
		}
	})

	t.Run("empty string is not reserved", func(t *testing.T) {
		t.Parallel()
		assert.False(t, isReservedWord(""))
	})
}

func TestValidateSteps(t *testing.T) {
	t.Parallel()

	// Use a non-empty executor type to avoid triggering command validators
	// that may be registered via init() from other packages
	testExec := ExecutorConfig{Type: "test-no-validator"}

	t.Run("valid DAG with steps passes validation", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ExecutorConfig: testExec},
				{Name: "step2", Depends: []string{"step1"}, ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("empty DAG passes validation", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{Steps: []Step{}}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	// Pass 1: ID validation tests
	t.Run("step with valid ID passes", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "myStepId", ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("step with invalid ID format fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "1invalid"}, // starts with number
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "invalid step ID format")
	})

	t.Run("step with reserved word ID fails", func(t *testing.T) {
		t.Parallel()
		reservedWords := []string{"env", "params", "args", "stdout", "stderr", "output", "outputs"}
		for _, word := range reservedWords {
			dag := &DAG{
				Steps: []Step{
					{Name: "step1", ID: word},
				},
			}
			err := ValidateSteps(dag)
			require.Error(t, err, "ID %q should be rejected as reserved", word)
			assert.Contains(t, err.Error(), "reserved word")
		}
	})

	t.Run("duplicate step names fail", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "duplicate"},
				{Name: "duplicate"},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.True(t, errors.Is(err, ErrStepNameDuplicate))
	})

	t.Run("duplicate step IDs fail", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "sameId"},
				{Name: "step2", ID: "sameId"},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "duplicate step ID")
	})

	// Pass 2: Name/ID conflict tests
	t.Run("step ID conflicts with another step name fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "conflictName", ExecutorConfig: testExec},
				{Name: "step2", ID: "conflictName", ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		// The validator detects that step name "conflictName" conflicts with another step's ID
		assert.Contains(t, err.Error(), "conflicts")
	})

	t.Run("step name conflicts with another step ID fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "conflictId", ExecutorConfig: testExec},
				{Name: "conflictId", ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		// The validator detects that step ID "conflictId" conflicts with another step's name
		assert.Contains(t, err.Error(), "conflicts")
	})

	t.Run("same step has matching name and ID is allowed", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "sameName", ID: "sameName", ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	// Pass 3 & 4: Dependency tests
	t.Run("valid dependencies pass", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ExecutorConfig: testExec},
				{Name: "step2", Depends: []string{"step1"}, ExecutorConfig: testExec},
				{Name: "step3", Depends: []string{"step1", "step2"}, ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("non-existent dependency fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1"},
				{Name: "step2", Depends: []string{"nonexistent"}},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "non-existent step")
	})

	t.Run("ID reference in depends is resolved to name", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "s1", ExecutorConfig: testExec},
				{Name: "step2", Depends: []string{"s1"}, ExecutorConfig: testExec}, // references ID
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
		// After validation, depends should be resolved to name
		assert.Contains(t, dag.Steps[1].Depends, "step1")
	})

	// Pass 5: Step validation tests
	t.Run("step with empty name fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: ""},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		// The internal error is "step name not generated"
		assert.Contains(t, err.Error(), "step name")
	})

	t.Run("step name too long fails", func(t *testing.T) {
		t.Parallel()
		longName := strings.Repeat("a", 41) // 41 chars, max is 40
		dag := &DAG{
			Steps: []Step{
				{Name: longName},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.True(t, errors.Is(err, ErrStepNameTooLong))
	})

	t.Run("step name at max length passes", func(t *testing.T) {
		t.Parallel()
		maxName := strings.Repeat("a", 40) // exactly 40 chars
		dag := &DAG{
			Steps: []Step{
				{Name: maxName, ExecutorConfig: testExec},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	// Parallel config validation
	t.Run("parallel config without SubDAG fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: 2,
						Items:         []ParallelItem{{Value: "a"}, {Value: "b"}},
					},
					// SubDAG is nil
				},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "only supported for child-DAGs")
	})

	t.Run("parallel config with maxConcurrent 0 fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: 0,
						Items:         []ParallelItem{{Value: "a"}, {Value: "b"}},
					},
					SubDAG: &SubDAG{Name: "child"},
				},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "maxConcurrent must be greater than 0")
	})

	t.Run("parallel config with negative maxConcurrent fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: -1,
						Items:         []ParallelItem{{Value: "a"}, {Value: "b"}},
					},
					SubDAG: &SubDAG{Name: "child"},
				},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "maxConcurrent must be greater than 0")
	})

	t.Run("parallel config without items or variable fails", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: 2,
						// no items, no variable
					},
					SubDAG: &SubDAG{Name: "child"},
				},
			},
		}
		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "must have either items array or variable reference")
	})

	t.Run("valid parallel config with items passes", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: 2,
						Items:         []ParallelItem{{Value: "a"}, {Value: "b"}, {Value: "c"}},
					},
					SubDAG: &SubDAG{Name: "child"},
				},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("valid parallel config with variable passes", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{
					Name: "step1",
					Parallel: &ParallelConfig{
						MaxConcurrent: 2,
						Variable:      "ITEMS",
					},
					SubDAG: &SubDAG{Name: "child"},
				},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})
}

func TestRegisterStepValidator(t *testing.T) {
	// Note: These tests modify global state, so they should not run in parallel
	// with each other. Each test should clean up after itself.

	t.Run("register validator for new type", func(t *testing.T) {
		// Clean up after test
		defer delete(stepValidators, "test-executor")

		validatorCalled := false
		validator := func(_ Step) error {
			validatorCalled = true
			return nil
		}

		RegisterStepValidator("test-executor", validator)

		// Create a DAG with a step using this executor type
		dag := &DAG{
			Steps: []Step{
				{
					Name:           "step1",
					ExecutorConfig: ExecutorConfig{Type: "test-executor"},
				},
			},
		}

		err := ValidateSteps(dag)
		assert.NoError(t, err)
		assert.True(t, validatorCalled, "validator should have been called")
	})

	t.Run("validator returning error propagates", func(t *testing.T) {
		defer delete(stepValidators, "error-executor")

		expectedErr := errors.New("validation failed")
		validator := func(_ Step) error {
			return expectedErr
		}

		RegisterStepValidator("error-executor", validator)

		dag := &DAG{
			Steps: []Step{
				{
					Name:           "step1",
					ExecutorConfig: ExecutorConfig{Type: "error-executor"},
				},
			},
		}

		err := ValidateSteps(dag)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "validation failed")
	})

	t.Run("overwrite existing validator", func(t *testing.T) {
		defer delete(stepValidators, "overwrite-executor")

		firstCalled := false
		secondCalled := false

		first := func(_ Step) error {
			firstCalled = true
			return nil
		}
		second := func(_ Step) error {
			secondCalled = true
			return nil
		}

		RegisterStepValidator("overwrite-executor", first)
		RegisterStepValidator("overwrite-executor", second)

		dag := &DAG{
			Steps: []Step{
				{
					Name:           "step1",
					ExecutorConfig: ExecutorConfig{Type: "overwrite-executor"},
				},
			},
		}

		err := ValidateSteps(dag)
		assert.NoError(t, err)
		assert.False(t, firstCalled, "first validator should not be called")
		assert.True(t, secondCalled, "second validator should be called")
	})

	t.Run("no validator for type does not fail", func(t *testing.T) {
		dag := &DAG{
			Steps: []Step{
				{
					Name:           "step1",
					ExecutorConfig: ExecutorConfig{Type: "unregistered-executor"},
				},
			},
		}

		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})
}

func TestResolveStepDependencies(t *testing.T) {
	t.Parallel()

	t.Run("resolves ID references to names", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "firstStep", ID: "first"},
				{Name: "secondStep", ID: "second"},
				{Name: "thirdStep", Depends: []string{"first", "second"}},
			},
		}

		err := resolveStepDependencies(dag)
		require.NoError(t, err)

		// Dependencies should be resolved to names
		assert.Contains(t, dag.Steps[2].Depends, "firstStep")
		assert.Contains(t, dag.Steps[2].Depends, "secondStep")
	})

	t.Run("preserves name references", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1"},
				{Name: "step2", Depends: []string{"step1"}}, // uses name, not ID
			},
		}

		err := resolveStepDependencies(dag)
		require.NoError(t, err)

		// Name reference should remain unchanged
		assert.Contains(t, dag.Steps[1].Depends, "step1")
	})

	t.Run("mixed ID and name references", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "s1"},
				{Name: "step2"},
				{Name: "step3", Depends: []string{"s1", "step2"}}, // mix of ID and name
			},
		}

		err := resolveStepDependencies(dag)
		require.NoError(t, err)

		// ID should be resolved, name should remain
		assert.Contains(t, dag.Steps[2].Depends, "step1") // s1 resolved to step1
		assert.Contains(t, dag.Steps[2].Depends, "step2") // step2 unchanged
	})

	t.Run("empty DAG", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{Steps: []Step{}}

		err := resolveStepDependencies(dag)
		assert.NoError(t, err)
	})

	t.Run("steps without dependencies", func(t *testing.T) {
		t.Parallel()
		dag := &DAG{
			Steps: []Step{
				{Name: "step1", ID: "s1"},
				{Name: "step2", ID: "s2"},
			},
		}

		err := resolveStepDependencies(dag)
		assert.NoError(t, err)
		assert.Empty(t, dag.Steps[0].Depends)
		assert.Empty(t, dag.Steps[1].Depends)
	})
}

func TestValidateStep(t *testing.T) {
	t.Parallel()

	// Use a non-empty executor type to avoid triggering command validators
	testExecutorType := "test-no-validator"

	t.Run("valid step passes", func(t *testing.T) {
		t.Parallel()
		step := Step{Name: "validStep", ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
		err := validateStep(step)
		assert.NoError(t, err)
	})

	t.Run("empty name fails", func(t *testing.T) {
		t.Parallel()
		step := Step{Name: "", ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
		err := validateStep(step)
		require.Error(t, err)
		assert.True(t, errors.Is(err, ErrStepNameRequired))
	})

	t.Run("name too long fails", func(t *testing.T) {
		t.Parallel()
		step := Step{Name: strings.Repeat("x", 41), ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
		err := validateStep(step)
		require.Error(t, err)
		assert.True(t, errors.Is(err, ErrStepNameTooLong))
	})

	t.Run("name at exactly max length passes", func(t *testing.T) {
		t.Parallel()
		step := Step{Name: strings.Repeat("x", 40), ExecutorConfig: ExecutorConfig{Type: testExecutorType}}
		err := validateStep(step)
		assert.NoError(t, err)
	})
}

func TestValidateStepWithValidator(t *testing.T) {
	t.Run("no validator returns nil", func(t *testing.T) {
		step := Step{
			Name:           "step1",
			ExecutorConfig: ExecutorConfig{Type: "unknown-type"},
		}
		err := validateStepWithValidator(step)
		assert.NoError(t, err)
	})

	t.Run("nil validator returns nil", func(t *testing.T) {
		defer delete(stepValidators, "nil-validator-type")
		stepValidators["nil-validator-type"] = nil

		step := Step{
			Name:           "step1",
			ExecutorConfig: ExecutorConfig{Type: "nil-validator-type"},
		}
		err := validateStepWithValidator(step)
		assert.NoError(t, err)
	})

	t.Run("validator error is wrapped", func(t *testing.T) {
		defer delete(stepValidators, "wrap-error-type")

		customErr := errors.New("custom validation error")
		stepValidators["wrap-error-type"] = func(_ Step) error {
			return customErr
		}

		step := Step{
			Name:           "step1",
			ExecutorConfig: ExecutorConfig{Type: "wrap-error-type"},
		}
		err := validateStepWithValidator(step)
		require.Error(t, err)

		// Should be wrapped in ValidationError
		var ve *ValidationError
		require.True(t, errors.As(err, &ve))
		assert.Equal(t, "executorConfig", ve.Field)
		assert.True(t, errors.Is(err, customErr))
	})
}

func TestValidateSteps_ComplexScenarios(t *testing.T) {
	t.Parallel()

	// Use a non-empty executor type to avoid triggering command validators
	// that may be registered via init() from other packages
	testExecutorType := "test-no-validator"

	t.Run("large DAG with many steps", func(t *testing.T) {
		t.Parallel()

		// Create a DAG with 100 steps in a chain
		steps := make([]Step, 100)
		for i := 0; i < 100; i++ {
			steps[i] = Step{
				Name:           fmt.Sprintf("step%d", i),
				ExecutorConfig: ExecutorConfig{Type: testExecutorType},
			}
			if i > 0 {
				steps[i].Depends = []string{fmt.Sprintf("step%d", i-1)}
			}
		}

		dag := &DAG{Steps: steps}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("diamond dependency pattern", func(t *testing.T) {
		t.Parallel()

		//   A
		//  / \
		// B   C
		//  \ /
		//   D
		dag := &DAG{
			Steps: []Step{
				{Name: "A", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				{Name: "B", Depends: []string{"A"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				{Name: "C", Depends: []string{"A"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				{Name: "D", Depends: []string{"B", "C"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})

	t.Run("multiple independent chains", func(t *testing.T) {
		t.Parallel()

		dag := &DAG{
			Steps: []Step{
				// Chain 1
				{Name: "chain1-step1", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				{Name: "chain1-step2", Depends: []string{"chain1-step1"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				// Chain 2
				{Name: "chain2-step1", ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
				{Name: "chain2-step2", Depends: []string{"chain2-step1"}, ExecutorConfig: ExecutorConfig{Type: testExecutorType}},
			},
		}
		err := ValidateSteps(dag)
		assert.NoError(t, err)
	})
}
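The tests above pin down a small registry keyed by executor type. A sketch of that pattern under the names the test file uses — the real definitions live in internal/core and may differ in detail:

	// Sketch only: mirrors the registry behavior the tests assert.
	type ValidationError struct {
		Field string
		Err   error
	}

	func (e *ValidationError) Error() string { return e.Field + ": " + e.Err.Error() }
	func (e *ValidationError) Unwrap() error { return e.Err } // lets errors.Is reach the cause

	var stepValidators = map[string]StepValidator{}

	// RegisterStepValidator installs (or overwrites) the validator for an
	// executor type; the last registration wins, as the tests assert.
	func RegisterStepValidator(executorType string, v StepValidator) {
		stepValidators[executorType] = v
	}

	func validateStepWithValidator(step Step) error {
		v, ok := stepValidators[step.ExecutorConfig.Type]
		if !ok || v == nil {
			return nil // unregistered or nil validators are a no-op
		}
		if err := v(step); err != nil {
			return &ValidationError{Field: "executorConfig", Err: err}
		}
		return nil
	}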
@@ -32,6 +32,8 @@ type dockerExecutorTest struct {
}

func TestDockerExecutor(t *testing.T) {
	t.Parallel()

	tests := []dockerExecutorTest{
		{
			name: "BasicExecution",
@@ -90,6 +92,8 @@ type containerTest struct {
}

func TestDAGLevelContainer(t *testing.T) {
	t.Parallel()

	tests := []containerTest{
		{
			name: "VolumeBindMounts",
@@ -237,8 +241,6 @@ steps:

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			tempDir, err := os.MkdirTemp("", fmt.Sprintf("%s-%s-*", containerPrefix, tt.name))
			require.NoError(t, err, "failed to create temporary directory")
			t.Cleanup(func() { _ = os.RemoveAll(tempDir) })
@@ -513,3 +515,233 @@ func waitForContainerStop(t *testing.T, th test.Helper, dockerClient *client.Cli
	}
	return false
}

// TestStepLevelContainer tests the new step-level container syntax
// which allows specifying a container field directly on a step instead of
// using the executor syntax.
func TestStepLevelContainer(t *testing.T) {
	t.Parallel()

	tests := []containerTest{
		{
			name: "BasicStepContainer",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: run-in-container
    container:
      image: %s
    command: echo "hello from step container"
    output: STEP_CONTAINER_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"STEP_CONTAINER_OUT": "hello from step container",
			},
		},
		{
			name: "StepContainerWithWorkingDir",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: check-workdir
    container:
      image: %s
      workingDir: /tmp
    command: pwd
    output: STEP_WORKDIR_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"STEP_WORKDIR_OUT": "/tmp",
			},
		},
		{
			name: "StepContainerWithEnv",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: check-env
    container:
      image: %s
      env:
        - MY_VAR=hello_world
    command: sh -c "echo $MY_VAR"
    output: STEP_ENV_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"STEP_ENV_OUT": "hello_world",
			},
		},
		{
			name: "StepContainerWithVolume",
			dagConfigFunc: func(tempDir string) string {
				return fmt.Sprintf(`
steps:
  - name: write-file
    container:
      image: %s
      volumes:
        - %s:/data
    command: sh -c "echo 'step volume test' > /data/step_test.txt"
  - name: read-file
    container:
      image: %s
      volumes:
        - %s:/data
    command: cat /data/step_test.txt
    output: STEP_VOL_OUT
    depends:
      - write-file
`, testImage, tempDir, testImage, tempDir)
			},
			expectedOutputs: map[string]any{
				"STEP_VOL_OUT": "step volume test",
			},
		},
		{
			name: "MultipleStepsWithDifferentContainers",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: alpine-step
    container:
      image: %s
    command: cat /etc/alpine-release
    output: ALPINE_VERSION
  - name: busybox-step
    container:
      image: busybox:latest
    command: echo "busybox step"
    output: BUSYBOX_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"BUSYBOX_OUT": "busybox step",
			},
		},
		{
			name: "StepContainerOverridesDAGContainer",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
# DAG-level container - steps without container field use this
container:
  image: busybox:latest

steps:
  - name: use-dag-container
    command: echo "in DAG container"
    output: DAG_CONTAINER_OUT
  - name: use-step-container
    container:
      image: %s
    command: cat /etc/alpine-release
    output: STEP_CONTAINER_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"DAG_CONTAINER_OUT": "in DAG container",
			},
		},
		{
			name: "StepContainerWithUser",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: check-user
    container:
      image: %s
      user: "nobody"
    command: whoami
    output: STEP_USER_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"STEP_USER_OUT": "nobody",
			},
		},
		{
			name: "StepContainerWithPullPolicy",
			dagConfigFunc: func(_ string) string {
				return fmt.Sprintf(`
steps:
  - name: pull-never
    container:
      image: %s
      pullPolicy: never
    command: echo "pull never ok"
    output: PULL_NEVER_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"PULL_NEVER_OUT": "pull never ok",
			},
		},
		{
			name: "StepEnvMergedIntoContainer",
			dagConfigFunc: func(_ string) string {
				// Test that step.env is merged with container.env
				// container.env takes precedence for shared keys
				// Use printenv to show actual environment in container
				// Note: SEMIC_ prefix is an abbreviation of the test name (StepEnvMergedIntoContainerEnv)
				// to avoid environment variable collisions between tests
				return fmt.Sprintf(`
steps:
  - name: check-merged-env
    env:
      - SEMIC_STEP_VAR=from_step
      - SEMIC_SHARED_VAR=step_value
    container:
      image: %s
      env:
        - SEMIC_CONTAINER_VAR=from_container
        - SEMIC_SHARED_VAR=container_value
    command: printenv SEMIC_SHARED_VAR
    output: SEMIC_MERGED_ENV_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				// SEMIC_SHARED_VAR should be container_value (container.env takes precedence)
				"SEMIC_MERGED_ENV_OUT": "container_value",
			},
		},
		{
			name: "StepEnvOnlyPassedToContainer",
			dagConfigFunc: func(_ string) string {
				// Test that step.env is passed to container even without container.env
				return fmt.Sprintf(`
steps:
  - name: step-env-only
    env:
      - MY_STEP_VAR=hello_from_step
    container:
      image: %s
    command: printenv MY_STEP_VAR
    output: STEP_ENV_ONLY_OUT
`, testImage)
			},
			expectedOutputs: map[string]any{
				"STEP_ENV_ONLY_OUT": "hello_from_step",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tempDir, err := os.MkdirTemp("", fmt.Sprintf("%s-step-%s-*", containerPrefix, tt.name))
			require.NoError(t, err, "failed to create temporary directory")
			t.Cleanup(func() { _ = os.RemoveAll(tempDir) })

			if tt.setupFunc != nil {
				tt.setupFunc(t, tempDir)
			}

			th := test.Setup(t)
			dag := th.DAG(t, tt.dagConfigFunc(tempDir))
			dag.Agent().RunSuccess(t)
			dag.AssertLatestStatus(t, core.Succeeded)
			dag.AssertOutputs(t, tt.expectedOutputs)
		})
	}
}
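The env-merge semantics asserted above (step.env merged into the container, with container.env winning for shared keys) amount to last-wins merging of KEY=value lists. A standalone sketch with a hypothetical mergeEnv helper — not the executor's actual code:

	package main

	import (
		"fmt"
		"strings"
	)

	// mergeEnv overlays container-level entries on step-level ones;
	// for duplicate keys the later (container) value wins.
	func mergeEnv(stepEnv, containerEnv []string) []string {
		seen := map[string]int{}
		var out []string
		for _, kv := range append(append([]string{}, stepEnv...), containerEnv...) {
			key := kv
			if i := strings.Index(kv, "="); i != -1 {
				key = kv[:i]
			}
			if at, ok := seen[key]; ok {
				out[at] = kv // overwrite earlier value for the same key
				continue
			}
			seen[key] = len(out)
			out = append(out, kv)
		}
		return out
	}

	func main() {
		fmt.Println(mergeEnv(
			[]string{"SEMIC_STEP_VAR=from_step", "SEMIC_SHARED_VAR=step_value"},
			[]string{"SEMIC_CONTAINER_VAR=from_container", "SEMIC_SHARED_VAR=container_value"},
		))
		// [SEMIC_STEP_VAR=from_step SEMIC_SHARED_VAR=container_value SEMIC_CONTAINER_VAR=from_container]
	}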
@@ -50,15 +50,15 @@ steps:
	// Load the DAG
	dagWrapper := coord.DAG(t, yamlContent)

	// Build the start command spec
	// Build the enqueue command spec
	subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
	startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
	enqueueSpec := subCmdBuilder.Enqueue(dagWrapper.DAG, runtime.EnqueueOptions{
		Quiet: true,
	})

	// Execute the start command (spawns subprocess)
	err := runtime.Start(coord.Context, startSpec)
	require.NoError(t, err, "Start command should succeed")
	// Execute the enqueue command (spawns subprocess)
	err := runtime.Start(coord.Context, enqueueSpec)
	require.NoError(t, err, "Enqueue command should succeed")

	// Wait for the subprocess to complete enqueueing
	require.Eventually(t, func() bool {
@@ -138,10 +138,10 @@ steps:
		t.Log("E2E test completed successfully!")
	})

	t.Run("E2E_StartCommand_WithNoQueueFlag_ShouldExecuteDirectly", func(t *testing.T) {
		// Verify that --no-queue flag bypasses enqueueing even when workerSelector exists
	t.Run("E2E_StartCommand_WorkerSelector_ShouldExecuteLocally", func(t *testing.T) {
		// Verify that dagu start always executes locally even when workerSelector exists
		yamlContent := `
name: no-queue-dag
name: local-start-dag
workerSelector:
  test: value
steps:
@@ -154,11 +154,10 @@ steps:
		// Load the DAG
		dagWrapper := coord.DAG(t, yamlContent)

		// Build start command WITH --no-queue flag
		// Build start command
		subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
		startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
			Quiet:   true,
			NoQueue: true, // This bypasses enqueueing
			Quiet: true,
		})

		err := runtime.Start(ctx, startSpec)
@@ -170,7 +169,7 @@ steps:
		// Should NOT be enqueued (executed directly)
		queueItems, err := coord.QueueStore.ListByDAGName(ctx, dagWrapper.ProcGroup(), dagWrapper.Name)
		require.NoError(t, err)
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued when --no-queue is set")
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")
	})

	t.Run("E2E_DistributedExecution_Cancellation_SubDAG", func(t *testing.T) {
@@ -5,20 +5,12 @@ import (
	"time"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/runtime"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/require"
)

// TestStartCommandWithWorkerSelector tests that the start command enqueues
// DAGs with workerSelector instead of executing them locally, and that the
// scheduler dispatches them to workers correctly.
//
// This is the integration test for the distributed execution fix where:
// 1. start command checks for workerSelector → enqueues (instead of executing)
// 2. Scheduler queue handler picks it up → dispatches to coordinator
// 3. Worker executes with --no-queue flag → executes directly (no re-enqueue)
// 3. Worker executes directly as ordered by the scheduler (no re-enqueue)
func TestStartCommandWithWorkerSelector(t *testing.T) {
	t.Run("StartCommand_WithWorkerSelector_ShouldEnqueue", func(t *testing.T) {
		// This test verifies that when a DAG has workerSelector,
@@ -53,16 +45,16 @@ steps:
		err := runtime.Start(coord.Context, startSpec)
		require.NoError(t, err, "Start command should succeed")

		// Wait for the DAG to be enqueued
		// Wait for completion (executed locally)
		require.Eventually(t, func() bool {
			queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
			return err == nil && len(queueItems) == 1
		}, 2*time.Second, 50*time.Millisecond, "DAG should be enqueued")
			status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
			return err == nil && status.Status == core.Succeeded
		}, 2*time.Second, 50*time.Millisecond, "DAG should complete successfully")

		// Verify the DAG was enqueued (not executed locally)
		// Verify the DAG was NOT enqueued (executed locally)
		queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
		require.NoError(t, err)
		require.Len(t, queueItems, 1, "DAG should be enqueued once")
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")

		if len(queueItems) > 0 {
			data, err := queueItems[0].Data()
@@ -70,18 +62,18 @@ steps:
			t.Logf("DAG enqueued: dag=%s runId=%s", data.Name, data.ID)
		}

		// Verify the DAG status is "queued" (not started/running)
		// Verify the DAG status is "succeeded" (executed locally)
		latest, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
		require.NoError(t, err)
		require.Equal(t, core.Queued, latest.Status, "DAG status should be queued")
		require.Equal(t, core.Succeeded, latest.Status, "DAG status should be succeeded")
	})

	t.Run("StartCommand_WithNoQueueFlag_ShouldExecuteDirectly", func(t *testing.T) {
		// Verify that --no-queue flag bypasses enqueueing
	t.Run("StartCommand_WorkerSelector_ShouldExecuteLocally", func(t *testing.T) {
		// Verify that dagu start always executes locally
		// even when workerSelector exists

		yamlContent := `
name: no-queue-dag
name: local-start-dag
workerSelector:
  test: value
steps:
@@ -94,11 +86,10 @@ steps:
		// Load the DAG
		dagWrapper := coord.DAG(t, yamlContent)

		// Build start command WITH --no-queue flag
		// Build start command
		subCmdBuilder := runtime.NewSubCmdBuilder(coord.Config)
		startSpec := subCmdBuilder.Start(dagWrapper.DAG, runtime.StartOptions{
			Quiet:   true,
			NoQueue: true, // This bypasses enqueueing
			Quiet: true,
		})

		err := runtime.Start(ctx, startSpec)
@@ -107,7 +98,7 @@ steps:
		// Should NOT be enqueued (executed directly)
		queueItems, err := coord.QueueStore.ListByDAGName(ctx, dagWrapper.ProcGroup(), dagWrapper.Name)
		require.NoError(t, err)
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued when --no-queue is set")
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")

		// Verify it succeeded (executed locally)
		dagWrapper.AssertLatestStatus(t, core.Succeeded)
@@ -140,54 +131,39 @@ steps:
			Quiet: true,
		})

		// Execute the start command (runs locally now)
		err := runtime.Start(coord.Context, startSpec)
		require.NoError(t, err, "Start command should succeed")
		require.NoError(t, err, "Start command should succeed (process started)")

		// Wait for the DAG to be enqueued
		// Wait for completion (executed locally)
		require.Eventually(t, func() bool {
			queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
			return err == nil && len(queueItems) == 1
		}, 2*time.Second, 50*time.Millisecond, "DAG should be enqueued")
			status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
			return err == nil && status.Status == core.Failed
		}, 5*time.Second, 100*time.Millisecond, "DAG should fail")

		// Verify the DAG was enqueued
		// Should NOT be enqueued
		queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
		require.NoError(t, err)
		require.Len(t, queueItems, 1, "DAG should be enqueued once")
		require.Len(t, queueItems, 0, "DAG should NOT be enqueued (dagu start runs locally)")

		var dagRunID string
		var dagRun execution.DAGRunRef
		if len(queueItems) > 0 {
			data, err := queueItems[0].Data()
			require.NoError(t, err, "Should be able to get queue item data")
			dagRunID = data.ID
			dagRun = *data
			t.Logf("DAG enqueued: dag=%s runId=%s", data.Name, data.ID)
		}

		// Dequeue it to simulate processing
		_, err = coord.QueueStore.DequeueByDAGRunID(coord.Context, dagWrapper.ProcGroup(), dagRun)
		status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
		require.NoError(t, err)
		dagRunID := status.DAGRunID
		t.Logf("DAG failed: dag=%s runId=%s", status.Name, status.DAGRunID)

		// Now retry the DAG - it should be enqueued again
		retrySpec := subCmdBuilder.Retry(dagWrapper.DAG, dagRunID, "", false)
		err = runtime.Run(coord.Context, retrySpec)
		require.NoError(t, err, "Retry command should succeed")
		// Now retry the DAG - it should run locally
		retrySpec := subCmdBuilder.Retry(dagWrapper.DAG, dagRunID, "")
		err = runtime.Start(coord.Context, retrySpec)
		require.NoError(t, err, "Retry command should succeed (process started)")

		// Wait for the retry to be enqueued
		// Wait for completion
		require.Eventually(t, func() bool {
			queueItems, err := coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
			return err == nil && len(queueItems) == 1
		}, 2*time.Second, 50*time.Millisecond, "Retry should be enqueued")
			status, err := coord.DAGRunMgr.GetLatestStatus(coord.Context, dagWrapper.DAG)
			return err == nil && status.Status == core.Failed
		}, 5*time.Second, 100*time.Millisecond, "Retry should fail")

		// Verify the retry was enqueued
		// Should NOT be enqueued
		queueItems, err = coord.QueueStore.ListByDAGName(coord.Context, dagWrapper.ProcGroup(), dagWrapper.Name)
		require.NoError(t, err)
		require.Len(t, queueItems, 1, "Retry should be enqueued once")

		if len(queueItems) > 0 {
			data, err := queueItems[0].Data()
			require.NoError(t, err, "Should be able to get queue item data")
			require.Equal(t, dagRunID, data.ID, "Should have same DAG run ID")
			t.Logf("Retry enqueued: dag=%s runId=%s", data.Name, data.ID)
		}
		require.Len(t, queueItems, 0, "Retry should NOT be enqueued (dagu retry runs locally)")
	}
@@ -1,42 +1,46 @@
package integration_test

import (
	"os/exec"
	"testing"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/require"
)

func TestGitHubActionsExecutor(t *testing.T) {
	t.Parallel()

	t.Skip("skip")
	t.Run("BasicExecution", func(t *testing.T) {
		t.Parallel()
		tmpDir := t.TempDir()

		th := test.Setup(t)
		dag := th.DAG(t, `steps:
		dag := th.DAG(t, `
workingDir: `+tmpDir+`
steps:
  - name: test-action
    command: actions/hello-world-javascript-action@main
    command: actions/checkout@v4
    executor:
      type: github_action
      config:
        runner: node:24-bookworm
        runner: node:25-bookworm
        params:
          who-to-greet: "Morning"
    output: ACTION_OUTPUT
          repository: dagu-org/dagu
          sparse-checkout: README.md
`)

		// Verify git is available
		_, err := exec.LookPath("git")
		require.NoError(t, err, "git is required for this test but not found in PATH")

		// Initialize git repo in the temp dir to satisfy act requirements
		cmd := exec.Command("git", "init", dag.WorkingDir)
		require.NoError(t, cmd.Run(), "failed to init git repo")

		agent := dag.Agent()

		agent.RunSuccess(t)

		dag.AssertLatestStatus(t, core.Succeeded)

		// Verify that container output was captured to stdout
		// The hello-world action should log "Hello, Morning!" to console
		dag.AssertOutputs(t, map[string]any{
			"ACTION_OUTPUT": []test.Contains{
				"Hello, Morning!",
			},
		})
	})
}
@ -2,11 +2,13 @@ package integration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dagu-org/dagu/internal/core"
|
||||
"github.com/dagu-org/dagu/internal/core/execution"
|
||||
"github.com/dagu-org/dagu/internal/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -255,3 +257,405 @@ steps:
	require.NotNil(t, status.OnCancel, "abort handler should have been executed")
	require.Equal(t, core.NodeSucceeded, status.OnCancel.Status)
}

// TestHandlerOn_EnvironmentVariables tests that special environment variables
// are accessible from each handler type.
//
// Environment variable availability by handler:
//
//	| Variable                 | Init    | Success   | Failure   | Cancel   | Exit             |
//	|--------------------------|---------|-----------|-----------|----------|------------------|
//	| DAG_NAME                 | ✓       | ✓         | ✓         | ✓        | ✓                |
//	| DAG_RUN_ID               | ✓       | ✓         | ✓         | ✓        | ✓                |
//	| DAG_RUN_LOG_FILE         | ✓       | ✓         | ✓         | ✓        | ✓                |
//	| DAG_RUN_STEP_NAME        | onInit  | onSuccess | onFailure | onCancel | onExit           |
//	| DAG_RUN_STATUS           | running | succeeded | failed    | aborted  | succeeded/failed |
//	| DAG_RUN_STEP_STDOUT_FILE | ✗       | ✗         | ✗         | ✗        | ✗                |
//	| DAG_RUN_STEP_STDERR_FILE | ✗       | ✗         | ✗         | ✗        | ✗                |
//
// Note: DAG_RUN_STATUS in the init handler is "running" because the DAG run has started
// but steps haven't executed yet. DAG_RUN_STEP_STDOUT_FILE and DAG_RUN_STEP_STDERR_FILE
// are not set for handlers because node.SetupEnv is not called during handler execution.
func TestHandlerOn_EnvironmentVariables(t *testing.T) {
	t.Parallel()

	// Helper to extract the value from "KEY=value" format
	extractValue := func(output string) string {
		if idx := strings.Index(output, "="); idx != -1 {
			return output[idx+1:]
		}
		return output
	}

	t.Run("InitHandler_BaseEnvVars", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		// Test that basic env vars (DAG_NAME, DAG_RUN_ID, DAG_RUN_LOG_FILE, DAG_RUN_STEP_NAME)
		// are available in the init handler
		dag := th.DAG(t, `
handlerOn:
  init:
    command: |
      echo "name:${DAG_NAME}|runid:${DAG_RUN_ID}|logfile:${DAG_RUN_LOG_FILE}|stepname:${DAG_RUN_STEP_NAME}"
    output: INIT_ENV_OUTPUT

steps:
  - name: step1
    command: "true"
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnInit, "init handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnInit.Status)
		require.NotNil(t, status.OnInit.OutputVariables, "init handler should have output variables")

		output, ok := status.OnInit.OutputVariables.Load("INIT_ENV_OUTPUT")
		require.True(t, ok, "INIT_ENV_OUTPUT should be set")

		outputStr := extractValue(output.(string))

		// Verify DAG_NAME is set and non-empty
		assert.Contains(t, outputStr, "name:", "output should contain name prefix")
		assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")

		// Verify DAG_RUN_ID is set (UUID format)
		assert.Contains(t, outputStr, "runid:", "output should contain runid prefix")
		assert.NotContains(t, outputStr, "runid:|", "DAG_RUN_ID should not be empty")

		// Verify DAG_RUN_LOG_FILE is set and contains .log
		assert.Contains(t, outputStr, "logfile:", "output should contain logfile prefix")
		assert.Contains(t, outputStr, ".log", "DAG_RUN_LOG_FILE should contain .log")

		// Verify DAG_RUN_STEP_NAME is set to "onInit"
		assert.Contains(t, outputStr, "stepname:onInit", "DAG_RUN_STEP_NAME should be 'onInit'")
	})

	t.Run("InitHandler_DAGRunStatus_IsRunning", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		// DAG_RUN_STATUS in the init handler is "running" because the DAG run has started
		// but steps haven't completed yet. This value is technically correct but not
		// as useful as having the final status (which isn't known yet).
		dag := th.DAG(t, `
handlerOn:
  init:
    command: echo "${DAG_RUN_STATUS}"
    output: INIT_STATUS

steps:
  - name: step1
    command: "true"
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnInit)
		require.NotNil(t, status.OnInit.OutputVariables)

		output, ok := status.OnInit.OutputVariables.Load("INIT_STATUS")
		require.True(t, ok)

		// DAG_RUN_STATUS is "running" for the init handler because steps haven't completed
		outputStr := extractValue(output.(string))
		assert.Equal(t, "running", outputStr, "DAG_RUN_STATUS should be 'running' in init handler")
	})

	t.Run("SuccessHandler_AllEnvVars", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		dag := th.DAG(t, `
handlerOn:
  success:
    command: |
      echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
    output: SUCCESS_ENV_OUTPUT

steps:
  - name: step1
    command: "true"
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnSuccess, "success handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnSuccess.Status)
		require.NotNil(t, status.OnSuccess.OutputVariables)

		output, ok := status.OnSuccess.OutputVariables.Load("SUCCESS_ENV_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Verify DAG_NAME is set
		assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")

		// Verify DAG_RUN_STATUS is "succeeded"
		assert.Contains(t, outputStr, "status:succeeded", "DAG_RUN_STATUS should be 'succeeded'")

		// Verify DAG_RUN_STEP_NAME is "onSuccess"
		assert.Contains(t, outputStr, "stepname:onSuccess", "DAG_RUN_STEP_NAME should be 'onSuccess'")
	})

	t.Run("FailureHandler_AllEnvVars", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		dag := th.DAG(t, `
handlerOn:
  failure:
    command: |
      echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
    output: FAILURE_ENV_OUTPUT

steps:
  - name: failing-step
    command: exit 1
`)
		agent := dag.Agent()
		agent.RunError(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnFailure, "failure handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnFailure.Status)
		require.NotNil(t, status.OnFailure.OutputVariables)

		output, ok := status.OnFailure.OutputVariables.Load("FAILURE_ENV_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Verify DAG_NAME is set
		assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")

		// Verify DAG_RUN_STATUS is "failed"
		assert.Contains(t, outputStr, "status:failed", "DAG_RUN_STATUS should be 'failed'")

		// Verify DAG_RUN_STEP_NAME is "onFailure"
		assert.Contains(t, outputStr, "stepname:onFailure", "DAG_RUN_STEP_NAME should be 'onFailure'")
	})

	t.Run("ExitHandler_AllEnvVars_OnSuccess", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		dag := th.DAG(t, `
handlerOn:
  exit:
    command: |
      echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
    output: EXIT_ENV_OUTPUT

steps:
  - name: step1
    command: "true"
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnExit, "exit handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnExit.Status)
		require.NotNil(t, status.OnExit.OutputVariables)

		output, ok := status.OnExit.OutputVariables.Load("EXIT_ENV_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Verify DAG_NAME is set
		assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")

		// Verify DAG_RUN_STATUS is "succeeded" (exit runs after success)
		assert.Contains(t, outputStr, "status:succeeded", "DAG_RUN_STATUS should be 'succeeded'")

		// Verify DAG_RUN_STEP_NAME is "onExit"
		assert.Contains(t, outputStr, "stepname:onExit", "DAG_RUN_STEP_NAME should be 'onExit'")
	})

	t.Run("ExitHandler_AllEnvVars_OnFailure", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		dag := th.DAG(t, `
handlerOn:
  exit:
    command: |
      echo "status:${DAG_RUN_STATUS}"
    output: EXIT_ENV_OUTPUT

steps:
  - name: failing-step
    command: exit 1
`)
		agent := dag.Agent()
		agent.RunError(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnExit, "exit handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnExit.Status)
		require.NotNil(t, status.OnExit.OutputVariables)

		output, ok := status.OnExit.OutputVariables.Load("EXIT_ENV_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Verify DAG_RUN_STATUS is "failed" (exit runs after failure)
		assert.Contains(t, outputStr, "status:failed", "DAG_RUN_STATUS should be 'failed'")
	})

	t.Run("CancelHandler_AllEnvVars", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		dag := th.DAG(t, `
handlerOn:
  abort:
    command: |
      echo "name:${DAG_NAME}|status:${DAG_RUN_STATUS}|stepname:${DAG_RUN_STEP_NAME}"
    output: CANCEL_ENV_OUTPUT

steps:
  - name: long-running
    command: sleep 10
`)
		dagAgent := dag.Agent()

		done := make(chan struct{})
		go func() {
			_ = dagAgent.Run(th.Context)
			close(done)
		}()

		dag.AssertLatestStatus(t, core.Running)
		dagAgent.Abort()
		<-done

		status := dagAgent.Status(th.Context)
		require.NotNil(t, status.OnCancel, "cancel handler should have been executed")
		require.Equal(t, core.NodeSucceeded, status.OnCancel.Status)
		require.NotNil(t, status.OnCancel.OutputVariables)

		output, ok := status.OnCancel.OutputVariables.Load("CANCEL_ENV_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Verify DAG_NAME is set
		assert.NotContains(t, outputStr, "name:|", "DAG_NAME should not be empty")

		// Verify DAG_RUN_STATUS is "aborted"
		assert.Contains(t, outputStr, "status:aborted", "DAG_RUN_STATUS should be 'aborted'")

		// Verify DAG_RUN_STEP_NAME is "onCancel"
		assert.Contains(t, outputStr, "stepname:onCancel", "DAG_RUN_STEP_NAME should be 'onCancel'")
	})

	t.Run("StepOutputVars_NotAvailableInHandlers", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		// DAG_RUN_STEP_STDOUT_FILE and DAG_RUN_STEP_STDERR_FILE are NOT set
		// for handlers because node.SetupEnv is not called for handler execution
		dag := th.DAG(t, `
handlerOn:
  success:
    command: |
      echo "stdout:${DAG_RUN_STEP_STDOUT_FILE:-UNSET}|stderr:${DAG_RUN_STEP_STDERR_FILE:-UNSET}"
    output: HANDLER_STEP_FILES

steps:
  - name: step1
    command: "true"
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnSuccess)
		require.NotNil(t, status.OnSuccess.OutputVariables)

		output, ok := status.OnSuccess.OutputVariables.Load("HANDLER_STEP_FILES")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// These should be unset/empty in handlers
		assert.Contains(t, outputStr, "stdout:UNSET", "DAG_RUN_STEP_STDOUT_FILE should not be set in handler")
		assert.Contains(t, outputStr, "stderr:UNSET", "DAG_RUN_STEP_STDERR_FILE should not be set in handler")
	})

	t.Run("Handlers_CanAccessStepOutputVariables", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		// Handlers can access output variables from steps that have completed
		dag := th.DAG(t, `
handlerOn:
  success:
    command: |
      echo "step_output:${STEP_OUTPUT}"
    output: SUCCESS_WITH_STEP_OUTPUT

steps:
  - name: producer
    command: echo "produced_value"
    output: STEP_OUTPUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnSuccess)
		require.NotNil(t, status.OnSuccess.OutputVariables)

		output, ok := status.OnSuccess.OutputVariables.Load("SUCCESS_WITH_STEP_OUTPUT")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Success handler should be able to access step output
		assert.Contains(t, outputStr, "step_output:produced_value", "handler should access step output")
	})

	t.Run("InitHandler_CannotAccessStepOutputVariables", func(t *testing.T) {
		t.Parallel()
		th := test.Setup(t)

		// Init handler runs BEFORE steps, so it cannot access step outputs
		dag := th.DAG(t, `
handlerOn:
  init:
    command: |
      echo "step_output:${STEP_OUTPUT:-NOT_YET_AVAILABLE}"
    output: INIT_STEP_ACCESS

steps:
  - name: producer
    command: echo "produced_value"
    output: STEP_OUTPUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)

		status := agent.Status(th.Context)
		require.NotNil(t, status.OnInit)
		require.NotNil(t, status.OnInit.OutputVariables)

		output, ok := status.OnInit.OutputVariables.Load("INIT_STEP_ACCESS")
		require.True(t, ok)

		outputStr := extractValue(output.(string))

		// Init handler cannot access step output (steps haven't run yet)
		assert.Contains(t, outputStr, "step_output:NOT_YET_AVAILABLE",
			"init handler should not access step outputs")
	})
}

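Taken together, these tests suggest a usage pattern like the following sketch (illustrative only; it relies on the availability table above, where DAG_RUN_STATUS is only final in the success, failure, cancel, and exit handlers):

    handlerOn:
      exit:
        command: |
          echo "run ${DAG_RUN_ID} of ${DAG_NAME} finished with status ${DAG_RUN_STATUS}"
          echo "log file: ${DAG_RUN_LOG_FILE}"

    steps:
      - name: work
        command: "true"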
@@ -1,463 +0,0 @@
package integration_test

import (
	"os"
	"testing"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/require"
)

func TestLocalDAGExecution(t *testing.T) {
	t.Run("SimpleLocalDAG", func(t *testing.T) {
		// Create a DAG with local sub-DAGs using the document separator
		yamlContent := `
steps:
  - name: run-local-child
    call: local-child
    params: "NAME=World"
    output: SUB_RESULT

  - echo "Child said ${SUB_RESULT.outputs.GREETING}"

---

name: local-child
params:
  - NAME
steps:
  - command: echo "Hello, ${NAME}!"
    output: GREETING

  - echo "Greeting was ${GREETING}"
`
		// Setup the test helper
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		// Run the DAG
		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		// Verify successful completion
		testDAG.AssertLatestStatus(t, core.Succeeded)

		// Get the full run status
		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// Verify the first step (run-local-child) completed successfully.
		// Note: the sub-DAG's output is not directly visible in the parent's stdout.
		require.Len(t, dagRunStatus.Nodes, 2)
		require.Equal(t, "run-local-child", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)

		// Verify the second step's output
		logContent, err := os.ReadFile(dagRunStatus.Nodes[1].Stdout)
		require.NoError(t, err)
		require.Contains(t, string(logContent), "Child said Hello, World!")
	})

	t.Run("ParallelLocalDAGExecution", func(t *testing.T) {
		// Create a DAG with parallel execution of local DAGs
		yamlContent := `
steps:
  - name: parallel-tasks
    call: worker-dag
    parallel:
      items:
        - TASK_ID=1 TASK_NAME=alpha
        - TASK_ID=2 TASK_NAME=beta
        - TASK_ID=3 TASK_NAME=gamma
      maxConcurrent: 2

---

name: worker-dag
params:
  - TASK_ID
  - TASK_NAME
steps:
  - echo "Starting task ${TASK_ID} - ${TASK_NAME}"
  - echo "Processing ${TASK_NAME} with ID ${TASK_ID}"
  - echo "Completed ${TASK_NAME}"
`
		// Setup the test helper
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		// Run the DAG
		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		// Verify successful completion
		testDAG.AssertLatestStatus(t, core.Succeeded)

		// Get the full run status
		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// For parallel execution, we should have one step that ran multiple instances
		require.Len(t, dagRunStatus.Nodes, 1)
		require.Equal(t, "parallel-tasks", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)
	})

	t.Run("NestedLocalDAGs", func(t *testing.T) {
		// Test that nested local DAGs beyond 1 level are not supported.
		// This should fail because middle-dag tries to run leaf-dag, but leaf-dag
		// is not visible to middle-dag (only to root-dag).
		yamlContent := `
steps:
  - name: run-middle-dag
    call: middle-dag
    params: "ROOT_PARAM=FromRoot"

---

name: middle-dag
params:
  - ROOT_PARAM
steps:
  - command: echo "Received ${ROOT_PARAM}"
    output: MIDDLE_OUTPUT

  - name: run-leaf-dag
    call: leaf-dag
    params: "MIDDLE_PARAM=${MIDDLE_OUTPUT} LEAF_PARAM=FromMiddle"

---

name: leaf-dag
params:
  - MIDDLE_PARAM
  - LEAF_PARAM
steps:
  - command: |
      echo "Middle: ${MIDDLE_PARAM}, Leaf: ${LEAF_PARAM}"
`
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		agent := testDAG.Agent()
		err := agent.Run(agent.Context)
		// The root DAG execution will fail because middle-dag fails
		require.Error(t, err)
		require.Contains(t, err.Error(), "failed")

		// This should fail because middle-dag cannot see leaf-dag
		testDAG.AssertLatestStatus(t, core.Failed)

		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// The root DAG should have one step that tried to run middle-dag
		require.Len(t, dagRunStatus.Nodes, 1)
		require.Equal(t, "run-middle-dag", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeFailed, dagRunStatus.Nodes[0].Status)
	})

	t.Run("LocalDAGWithConditionalExecution", func(t *testing.T) {
		// Test conditional execution with local DAGs
		yamlContent := `
env:
  - ENVIRONMENT: production
steps:
  - name: check-env
    command: echo "${ENVIRONMENT}"
    output: ENV_TYPE

  - name: run-prod-dag
    call: production-dag
    preconditions:
      - condition: "${ENV_TYPE}"
        expected: "production"

  - name: run-dev-dag
    call: development-dag
    preconditions:
      - condition: "${ENV_TYPE}"
        expected: "development"

---

name: production-dag
steps:
  - echo "Deploying to production"
  - echo "Verifying production deployment"

---

name: development-dag
steps:
  - echo "Building for development"
  - echo "Running development tests"
`
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		testDAG.AssertLatestStatus(t, core.Succeeded)

		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// Should have 3 steps: check-env, run-prod-dag, run-dev-dag
		require.Len(t, dagRunStatus.Nodes, 3)

		// Check the environment step
		require.Equal(t, "check-env", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)

		// The production DAG should run
		require.Equal(t, "run-prod-dag", dagRunStatus.Nodes[1].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)

		// The development DAG should be skipped
		require.Equal(t, "run-dev-dag", dagRunStatus.Nodes[2].Step.Name)
		require.Equal(t, core.NodeSkipped, dagRunStatus.Nodes[2].Status)
	})

	t.Run("LocalDAGWithOutputPassing", func(t *testing.T) {
		// Test passing outputs between local DAGs
		yamlContent := `
steps:
  - name: generate-data
    call: generator-dag
    output: GEN_OUTPUT

  - name: process-data
    call: processor-dag
    params: "INPUT_DATA=${GEN_OUTPUT.outputs.DATA}"

---

name: generator-dag
steps:
  - command: echo "test-value-42"
    output: DATA

---

name: processor-dag
params:
  - INPUT_DATA
steps:
  - command: echo "Processing ${INPUT_DATA}"
    output: RESULT

  - command: |
      echo "Validated: ${RESULT}"
`
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		testDAG.AssertLatestStatus(t, core.Succeeded)

		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// Should have 2 steps
		require.Len(t, dagRunStatus.Nodes, 2)

		// The first step generates data
		require.Equal(t, "generate-data", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)

		// The second step processes data
		require.Equal(t, "process-data", dagRunStatus.Nodes[1].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)
	})

	t.Run("LocalDAGReferencesNonExistent", func(t *testing.T) {
		// Test the error when referencing a non-existent local DAG
		yamlContent := `
steps:
  - name: run-missing-dag
    call: non-existent-dag

---

name: some-other-dag
steps:
  - echo "test"
`
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		agent := testDAG.Agent()
		err := agent.Run(agent.Context)
		// The agent will return an error when a step fails
		require.Error(t, err)
		require.Contains(t, err.Error(), "non-existent-dag")

		// Check that the DAG failed
		testDAG.AssertLatestStatus(t, core.Failed)

		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// Should have one step that failed
		require.Len(t, dagRunStatus.Nodes, 1)
		require.Equal(t, "run-missing-dag", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeFailed, dagRunStatus.Nodes[0].Status)
	})

	t.Run("LocalDAGWithComplexDependencies", func(t *testing.T) {
		// Test complex dependencies between local DAGs
		yamlContent := `
steps:
  - name: setup
    command: echo "Setting up"
    output: SETUP_STATUS

  - name: task1
    call: task-dag
    params: "TASK_NAME=Task1 SETUP=${SETUP_STATUS}"
    output: TASK1_RESULT

  - name: task2
    call: task-dag
    params: "TASK_NAME=Task2 SETUP=${SETUP_STATUS}"
    output: TASK2_RESULT

  - name: combine
    command: |
      echo "Combining ${TASK1_RESULT.outputs.RESULT} and ${TASK2_RESULT.outputs.RESULT}"
    depends:
      - task1
      - task2

---

name: task-dag
params:
  - TASK_NAME
  - SETUP
steps:
  - command: echo "${TASK_NAME} processing with ${SETUP}"
    output: RESULT
`
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		testDAG.AssertLatestStatus(t, core.Succeeded)

		dagRunStatus, err := th.DAGRunMgr.GetLatestStatus(th.Context, testDAG.DAG)
		require.NoError(t, err)

		// Should have 4 steps: setup, task1, task2, combine
		require.Len(t, dagRunStatus.Nodes, 4)

		// Verify each step
		require.Equal(t, "setup", dagRunStatus.Nodes[0].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[0].Status)

		require.Equal(t, "task1", dagRunStatus.Nodes[1].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[1].Status)

		require.Equal(t, "task2", dagRunStatus.Nodes[2].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[2].Status)

		require.Equal(t, "combine", dagRunStatus.Nodes[3].Step.Name)
		require.Equal(t, core.NodeSucceeded, dagRunStatus.Nodes[3].Status)

		// Verify the combine step's output
		logContent, err := os.ReadFile(dagRunStatus.Nodes[3].Stdout)
		require.NoError(t, err)
		require.Contains(t, string(logContent), "Combining")
		require.Contains(t, string(logContent), "Task1 processing with Setting up")
		require.Contains(t, string(logContent), "Task2 processing with Setting up")
	})

	t.Run("PartialSuccessParallel", func(t *testing.T) {
		// Create a DAG with parallel execution of a local DAG whose first step
		// fails but continues
		yamlContent := `
steps:
  - name: parallel-tasks
    call: worker-dag
    parallel:
      items:
        - TASK_ID=1 TASK_NAME=alpha
---

name: worker-dag
params:
  - TASK_ID
  - TASK_NAME
steps:
  - command: exit 1
    continueOn:
      failure: true

  - exit 0
`
		// Setup the test helper
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		// Run the DAG
		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		// Verify partially successful completion
		testDAG.AssertLatestStatus(t, core.PartiallySucceeded)
	})

	t.Run("PartialSuccessSubDAG", func(t *testing.T) {
		// Create a DAG that calls a local sub-DAG whose first step fails but continues
		yamlContent := `
steps:
  - name: parallel-tasks
    call: worker-dag
---

name: worker-dag
params:
  - TASK_ID
  - TASK_NAME
steps:
  - command: exit 1
    continueOn:
      failure: true

  - exit 0
`
		// Setup the test helper
		th := test.Setup(t)

		// Load the DAG using the helper
		testDAG := th.DAG(t, yamlContent)

		// Run the DAG
		agent := testDAG.Agent()
		require.NoError(t, agent.Run(agent.Context))

		// Verify partially successful completion
		testDAG.AssertLatestStatus(t, core.PartiallySucceeded)
	})
}
468	internal/integration/multi_command_test.go	Normal file
@@ -0,0 +1,468 @@
package integration_test

import (
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/spec"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
)

func TestMultipleCommands_Shell(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("Skipping Unix shell tests on Windows")
	}
	t.Parallel()

	th := test.Setup(t)

	t.Run("TwoCommands", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: multi-cmd
    command:
      - echo hello
      - echo world
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})

	t.Run("ThreeCommandsWithArgs", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: multi-cmd
    command:
      - echo "first command"
      - echo "second command"
      - echo "third command"
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "first command\nsecond command\nthird command",
		})
	})

	t.Run("CommandsWithEnvVars", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
env:
  - MY_VAR: hello
steps:
  - name: multi-cmd
    command:
      - echo $MY_VAR
      - echo "${MY_VAR} world"
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nhello world",
		})
	})

	t.Run("FirstCommandFails", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: multi-cmd
    command:
      - "false"
      - echo "should not run"
    output: OUT
`)
		agent := dag.Agent()
		agent.RunError(t)
		dag.AssertLatestStatus(t, core.Failed)
	})

	t.Run("SecondCommandFails", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: multi-cmd
    command:
      - echo "first runs"
      - "false"
      - echo "should not run"
    output: OUT
`)
		agent := dag.Agent()
		agent.RunError(t)
		dag.AssertLatestStatus(t, core.Failed)
	})

	t.Run("CommandsWithPipes", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
shell: /bin/bash
steps:
  - name: multi-cmd
    command:
      - echo "hello world" | tr 'h' 'H'
      - echo "foo bar" | tr 'f' 'F'
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "Hello world\nFoo bar",
		})
	})

	t.Run("CommandsWithWorkingDir", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: multi-cmd
    workingDir: /tmp
    command:
      - pwd
      - echo "done"
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "/tmp\ndone",
		})
	})

	t.Run("DependsOnPreviousStep", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "step1"
    output: STEP1_OUT
  - name: step2
    depends:
      - step1
    command:
      - echo "from step2"
      - echo "done"
    output: STEP2_OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"STEP1_OUT": "step1",
			"STEP2_OUT": "from step2\ndone",
		})
	})
}

func TestMultipleCommands_Docker(t *testing.T) {
	t.Parallel()

	const testImage = "alpine:3"

	t.Run("TwoCommandsInContainer", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		// Use startup: command to keep container running for multiple commands
		dag := th.DAG(t, fmt.Sprintf(`
steps:
  - name: multi-cmd
    container:
      image: %s
      startup: command
      command: ["sh", "-c", "while true; do sleep 3600; done"]
    command:
      - echo hello
      - echo world
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})

	t.Run("CommandsWithEnvInContainer", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		// Use startup: command to keep container running for multiple commands
		dag := th.DAG(t, fmt.Sprintf(`
steps:
  - name: multi-cmd
    container:
      image: %s
      startup: command
      command: ["sh", "-c", "while true; do sleep 3600; done"]
      env:
        - MY_VAR=hello
    command:
      - printenv MY_VAR
      - echo "done"
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\ndone",
		})
	})

	t.Run("FirstCommandFailsInContainer", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		// Use startup: command to keep container running for multiple commands
		dag := th.DAG(t, fmt.Sprintf(`
steps:
  - name: multi-cmd
    container:
      image: %s
      startup: command
      command: ["sh", "-c", "while true; do sleep 3600; done"]
    command:
      - "false"
      - echo "should not run"
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunError(t)
		dag.AssertLatestStatus(t, core.Failed)
	})

	t.Run("SecondCommandFailsInContainer", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		// Use startup: command to keep container running for multiple commands
		dag := th.DAG(t, fmt.Sprintf(`
steps:
  - name: multi-cmd
    container:
      image: %s
      startup: command
      command: ["sh", "-c", "while true; do sleep 3600; done"]
    command:
      - echo "first runs"
      - "false"
      - echo "should not run"
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunError(t)
		dag.AssertLatestStatus(t, core.Failed)
	})

	t.Run("DAGLevelContainerWithMultipleCommands", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		dag := th.DAG(t, fmt.Sprintf(`
container:
  image: %s
steps:
  - name: multi-cmd
    command:
      - echo hello
      - echo world
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})

	t.Run("MultipleStepsWithMultipleCommands", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		dag := th.DAG(t, fmt.Sprintf(`
container:
  image: %s
steps:
  - name: step1
    command:
      - echo "step1-cmd1"
      - echo "step1-cmd2"
    output: STEP1_OUT
  - name: step2
    depends:
      - step1
    command:
      - echo "step2-cmd1"
      - echo "step2-cmd2"
    output: STEP2_OUT
`, testImage))
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"STEP1_OUT": "step1-cmd1\nstep1-cmd2",
			"STEP2_OUT": "step2-cmd1\nstep2-cmd2",
		})
	})

	// Test step-level container without startup:command - uses default keepalive mode
	t.Run("StepContainerWithDefaultKeepalive", func(t *testing.T) {
		t.Parallel()

		th := test.Setup(t)
		// No startup:command - should use default keepalive mode
		dag := th.DAG(t, fmt.Sprintf(`
steps:
  - name: multi-cmd
    container:
      image: %s
    command:
      - echo hello
      - echo world
    output: OUT
`, testImage))
		agent := dag.Agent()
		agent.RunSuccess(t)
		dag.AssertLatestStatus(t, core.Succeeded)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})
}

func TestMultipleCommands_Validation(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("Skipping Unix shell tests on Windows")
	}
	t.Parallel()

	th := test.Setup(t)

	t.Run("JQExecutorRejectsMultipleCommands", func(t *testing.T) {
		t.Parallel()

		// Create a temp file with the DAG content
		tempDir := t.TempDir()
		filename := fmt.Sprintf("%s.yaml", uuid.New().String())
		testFile := filepath.Join(tempDir, filename)
		yamlContent := `
steps:
  - name: jq-multi
    executor: jq
    command:
      - ".foo"
      - ".bar"
    script: '{"foo": "bar"}'
`
		err := os.WriteFile(testFile, []byte(yamlContent), 0600)
		require.NoError(t, err)

		_, err = spec.Load(th.Context, testFile)
		require.Error(t, err, "expected error for multiple commands with jq executor")
		require.Contains(t, err.Error(), "executor does not support multiple commands")
	})

	t.Run("HTTPExecutorRejectsMultipleCommands", func(t *testing.T) {
		t.Parallel()

		// Create a temp file with the DAG content
		tempDir := t.TempDir()
		filename := fmt.Sprintf("%s.yaml", uuid.New().String())
		testFile := filepath.Join(tempDir, filename)
		yamlContent := `
steps:
  - name: http-multi
    executor: http
    command:
      - "GET https://example.com"
      - "POST https://example.com"
`
		err := os.WriteFile(testFile, []byte(yamlContent), 0600)
		require.NoError(t, err)

		_, err = spec.Load(th.Context, testFile)
		require.Error(t, err, "expected error for multiple commands with http executor")
		require.Contains(t, err.Error(), "executor does not support multiple commands")
	})

	t.Run("ShellExecutorAllowsMultipleCommands", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: shell-multi
    executor: shell
    command:
      - echo hello
      - echo world
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})

	t.Run("CommandExecutorAllowsMultipleCommands", func(t *testing.T) {
		t.Parallel()

		dag := th.DAG(t, `
steps:
  - name: cmd-multi
    executor: command
    command:
      - echo hello
      - echo world
    output: OUT
`)
		agent := dag.Agent()
		agent.RunSuccess(t)
		// Output is concatenated from all commands
		dag.AssertOutputs(t, map[string]any{
			"OUT": "hello\nworld",
		})
	})
}
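The multi-command behavior verified above reduces to one sketch (assembled from the test YAML; the step name is illustrative): commands run sequentially within a single step, execution stops at the first non-zero exit, and output captures the stdout of all commands joined by newlines.

    steps:
      - name: multi-cmd
        command:
          - echo hello
          - echo world
        output: OUT   # captured as "hello\nworld"

For container steps, the tests pin the container open with startup: command and a long-running sleep so that every listed command runs in the same container; without it, the default keepalive mode is used.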
439	internal/integration/outputs_collection_test.go	Normal file
@@ -0,0 +1,439 @@
package integration_test

import (
	"context"
	"encoding/json"
	"os"
	"path/filepath"
	"testing"

	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/persistence/filedagrun"
	"github.com/dagu-org/dagu/internal/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestOutputsCollection(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		dagYAML         string
		runFunc         func(*testing.T, context.Context, *test.Agent)
		validateFunc    func(*testing.T, execution.DAGRunStatus)
		validateOutputs func(*testing.T, map[string]string)
	}{
		{
			name: "SimpleStringOutput",
			dagYAML: `
steps:
  - name: produce-output
    command: echo "RESULT=42"
    output: RESULT
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
				require.Len(t, status.Nodes, 1)
				require.Equal(t, core.NodeSucceeded, status.Nodes[0].Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				// The output value includes the KEY= prefix from the command output
				assert.Equal(t, "RESULT=42", outputs["result"]) // SCREAMING_SNAKE to camelCase
			},
		},
		{
			name: "OutputWithCustomKey",
			dagYAML: `
steps:
  - name: produce-output
    command: echo "MY_VALUE=hello world"
    output:
      name: MY_VALUE
      key: customKeyName
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				// The value includes the original KEY= prefix
				assert.Equal(t, "MY_VALUE=hello world", outputs["customKeyName"])
				_, hasDefault := outputs["myValue"]
				assert.False(t, hasDefault, "should not have default key when custom key is specified")
			},
		},
		{
			name: "OutputWithOmit",
			dagYAML: `
steps:
  - name: step1
    command: echo "VISIBLE=yes"
    output: VISIBLE

  - name: step2
    command: echo "HIDDEN=secret"
    output:
      name: HIDDEN
      omit: true
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
				require.Len(t, status.Nodes, 2)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				assert.Equal(t, "VISIBLE=yes", outputs["visible"])
				_, hasHidden := outputs["hidden"]
				assert.False(t, hasHidden, "omitted output should not be in outputs.json")
			},
		},
		{
			name: "MultipleStepsWithOutputs",
			dagYAML: `
steps:
  - name: step1
    command: echo "COUNT=10"
    output: COUNT

  - name: step2
    command: echo "TOTAL=100"
    output: TOTAL

  - name: step3
    command: echo "STATUS=completed"
    output: STATUS
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
				require.Len(t, status.Nodes, 3)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				assert.Len(t, outputs, 3)
				assert.Equal(t, "COUNT=10", outputs["count"])
				assert.Equal(t, "TOTAL=100", outputs["total"])
				assert.Equal(t, "STATUS=completed", outputs["status"])
			},
		},
		{
			name: "LastOneWinsForDuplicateKeys",
			dagYAML: `
steps:
  - name: step1
    command: echo "VALUE=first"
    output: VALUE

  - name: step2
    depends: [step1]
    command: echo "VALUE=second"
    output: VALUE
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				// The last step wins
				assert.Equal(t, "VALUE=second", outputs["value"])
			},
		},
		{
			name: "NoOutputsProduced",
			dagYAML: `
steps:
  - name: step1
    command: echo "hello"
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				// No outputs.json should be created when there are no outputs
				assert.Nil(t, outputs)
			},
		},
		{
			name: "OutputWithDollarPrefix",
			dagYAML: `
steps:
  - name: step1
    command: echo "MY_VAR=value123"
    output: $MY_VAR
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				assert.Equal(t, "MY_VAR=value123", outputs["myVar"])
			},
		},
		{
			name: "MixedOutputConfigurations",
			dagYAML: `
steps:
  - name: simple
    command: echo "SIMPLE_OUT=simple_value"
    output: SIMPLE_OUT

  - name: with-key
    command: echo "KEYED=keyed_value"
    output:
      name: KEYED
      key: renamedKey

  - name: omitted
    command: echo "SECRET=secret_value"
    output:
      name: SECRET
      omit: true
`,
			runFunc: func(t *testing.T, _ context.Context, agent *test.Agent) {
				agent.RunSuccess(t)
			},
			validateFunc: func(t *testing.T, status execution.DAGRunStatus) {
				require.Equal(t, core.Succeeded, status.Status)
			},
			validateOutputs: func(t *testing.T, outputs map[string]string) {
				require.NotNil(t, outputs)
				assert.Len(t, outputs, 2) // simple + keyed, NOT secret
				assert.Equal(t, "SIMPLE_OUT=simple_value", outputs["simpleOut"])
				assert.Equal(t, "KEYED=keyed_value", outputs["renamedKey"])
				_, hasSecret := outputs["secret"]
				assert.False(t, hasSecret)
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			th := test.Setup(t)
			dag := th.DAG(t, tt.dagYAML)
			agent := dag.Agent()

			// Run the DAG
			tt.runFunc(t, agent.Context, agent)

			// Validate the DAG run status
			status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
			require.NoError(t, err)
			tt.validateFunc(t, status)

			// Read outputs.json if it exists
			outputs := readOutputsFile(t, th, dag.DAG)
			tt.validateOutputs(t, outputs)
		})
	}
}

func TestOutputsCollection_FailedDAG(t *testing.T) {
	t.Parallel()

	th := test.Setup(t)
	dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "BEFORE_FAIL=collected"
    output: BEFORE_FAIL

  - name: step2
    depends: [step1]
    command: exit 1

  - name: step3
    depends: [step2]
    command: echo "AFTER_FAIL=not_collected"
    output: AFTER_FAIL
`)
	agent := dag.Agent()

	_ = agent.Run(agent.Context)

	status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
	require.NoError(t, err)
	require.Equal(t, core.Failed, status.Status)

	// Outputs from successful steps should still be collected
	outputs := readOutputsFile(t, th, dag.DAG)
	require.NotNil(t, outputs)
	assert.Equal(t, "BEFORE_FAIL=collected", outputs["beforeFail"])
	_, hasAfterFail := outputs["afterFail"]
	assert.False(t, hasAfterFail, "output from step after failure should not be collected")
}

func TestOutputsCollection_CamelCaseConversion(t *testing.T) {
	t.Parallel()

	tests := []struct {
		envVarName    string
		expectedKey   string
		expectedValue string
	}{
		{"SIMPLE", "simple", "SIMPLE=test_value"},
		{"TWO_WORDS", "twoWords", "TWO_WORDS=test_value"},
		{"MULTIPLE_WORD_NAME", "multipleWordName", "MULTIPLE_WORD_NAME=test_value"},
		{"ALREADY_CAMEL_Case", "alreadyCamelCase", "ALREADY_CAMEL_Case=test_value"},
	}

	for _, tt := range tests {
		t.Run(tt.envVarName, func(t *testing.T) {
			t.Parallel()

			th := test.Setup(t)
			dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "`+tt.envVarName+`=test_value"
    output: `+tt.envVarName+`
`)
			agent := dag.Agent()
			agent.RunSuccess(t)

			outputs := readOutputsFile(t, th, dag.DAG)
			require.NotNil(t, outputs)
			// The value includes the KEY= prefix from the original output
			assert.Equal(t, tt.expectedValue, outputs[tt.expectedKey])
		})
	}
}

func TestOutputsCollection_SecretsMasked(t *testing.T) {
	t.Parallel()

	th := test.Setup(t)

	// Create a temporary secret file
	secretValue := "super-secret-api-token-xyz123"
	secretFile := th.TempFile(t, "secret.txt", []byte(secretValue))

	dag := th.DAG(t, `
secrets:
  - name: API_TOKEN
    provider: file
    key: `+secretFile+`

steps:
  - name: output-secret
    command: echo "TOKEN=${API_TOKEN}"
    output: TOKEN
`)
	agent := dag.Agent()
	agent.RunSuccess(t)

	status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
	require.NoError(t, err)
	require.Equal(t, core.Succeeded, status.Status)

	// Read outputs.json
	outputs := readOutputsFile(t, th, dag.DAG)
	require.NotNil(t, outputs)

	// The output value should contain the masked secret, not the actual value
	tokenOutput := outputs["token"]
	require.NotEmpty(t, tokenOutput)
	assert.NotContains(t, tokenOutput, secretValue, "secret value should be masked in outputs")
	assert.Contains(t, tokenOutput, "*******", "masked placeholder should appear in outputs")
}

func TestOutputsCollection_MetadataIncluded(t *testing.T) {
	t.Parallel()

	th := test.Setup(t)
	dag := th.DAG(t, `
steps:
  - name: step1
    command: echo "RESULT=42"
    output: RESULT
`)
	agent := dag.Agent()
	agent.RunSuccess(t)

	status, err := th.DAGRunMgr.GetLatestStatus(th.Context, dag.DAG)
	require.NoError(t, err)

	// Read the full outputs including metadata
	fullOutputs := readFullOutputsFile(t, th, dag.DAG)
	require.NotNil(t, fullOutputs)

	// Validate metadata
	assert.Equal(t, dag.Name, fullOutputs.Metadata.DAGName)
	assert.Equal(t, status.DAGRunID, fullOutputs.Metadata.DAGRunID)
	assert.NotEmpty(t, fullOutputs.Metadata.AttemptID)
	assert.Equal(t, "succeeded", fullOutputs.Metadata.Status)
	assert.NotEmpty(t, fullOutputs.Metadata.CompletedAt)

	// Validate that outputs are still present
	assert.Equal(t, "RESULT=42", fullOutputs.Outputs["result"])
}

// readOutputsFile reads the outputs.json file for a given DAG run.
// It returns just the outputs map for backward compatibility with existing tests.
func readOutputsFile(t *testing.T, th test.Helper, dag *core.DAG) map[string]string {
	t.Helper()

	fullOutputs := readFullOutputsFile(t, th, dag)
	if fullOutputs == nil {
		return nil
	}
	return fullOutputs.Outputs
}

// readFullOutputsFile reads the full outputs.json file including metadata
func readFullOutputsFile(t *testing.T, th test.Helper, dag *core.DAG) *execution.DAGRunOutputs {
	t.Helper()

	// Find the attempt directory
	dagRunsDir := th.Config.Paths.DAGRunsDir
	dagRunDir := filepath.Join(dagRunsDir, dag.Name, "dag-runs")

	// Walk to find the outputs.json file
	var outputsPath string
	_ = filepath.Walk(dagRunDir, func(path string, info os.FileInfo, err error) error {
		require.NoError(t, err)
		if info.Name() == filedagrun.OutputsFile {
			outputsPath = path
			return filepath.SkipAll
		}
		return nil
	})

	if outputsPath == "" {
		return nil
	}

	data, err := os.ReadFile(outputsPath)
	require.NoError(t, err)

	var outputs execution.DAGRunOutputs
	require.NoError(t, json.Unmarshal(data, &outputs))

	// Reject the old format (no metadata)
	require.NotEmpty(t, outputs.Metadata.DAGRunID)
	return &outputs
}
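Based on the assertions above, an outputs.json for the simple case might look roughly like this (a sketch only; the exact JSON key names are assumed to mirror the DAGRunOutputs struct fields and are not shown in this diff):

    {
      "metadata": {
        "dagName": "my-dag",
        "dagRunId": "<run id>",
        "attemptId": "<attempt id>",
        "status": "succeeded",
        "completedAt": "<timestamp>"
      },
      "outputs": {
        "result": "RESULT=42"
      }
    }

Note that each output key is the camelCase form of the variable name, and the value keeps the original KEY= prefix.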
@@ -8,6 +8,8 @@ import (
	"time"

+	"github.com/dagu-org/dagu/internal/cmd"
+	"github.com/dagu-org/dagu/internal/common/config"
	"github.com/dagu-org/dagu/internal/common/stringutil"
	"github.com/dagu-org/dagu/internal/core"
	"github.com/dagu-org/dagu/internal/core/execution"
	"github.com/dagu-org/dagu/internal/core/spec"

@@ -137,6 +139,304 @@ steps:
	require.Less(t, duration, 20*time.Second, "took too long: %v", duration)
}

// TestGlobalQueueMaxConcurrency verifies that a global queue with maxConcurrency > 1
|
||||
// processes multiple DAGs concurrently, and that the DAG's maxActiveRuns doesn't
|
||||
// override the global queue's maxConcurrency setting.
|
||||
func TestGlobalQueueMaxConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Configure a global queue with maxConcurrency = 3 BEFORE setup
|
||||
// so it gets written to the config file
|
||||
th := test.SetupCommand(t, test.WithConfigMutator(func(cfg *config.Config) {
|
||||
cfg.Queues.Enabled = true
|
||||
cfg.Queues.Config = []config.QueueConfig{
|
||||
{Name: "global-queue", MaxActiveRuns: 3},
|
||||
}
|
||||
}))
|
||||
|
||||
// Create a DAG with maxActiveRuns = 1 that uses the global queue
|
||||
// Each DAG sleeps for 1 second to ensure we can detect concurrent vs sequential execution
|
||||
dagContent := `name: concurrent-test
|
||||
queue: global-queue
|
||||
maxActiveRuns: 1
|
||||
steps:
|
||||
- name: sleep-step
|
||||
command: sleep 1
|
||||
`
|
||||
require.NoError(t, os.MkdirAll(th.Config.Paths.DAGsDir, 0755))
|
||||
dagFile := filepath.Join(th.Config.Paths.DAGsDir, "concurrent-test.yaml")
|
||||
require.NoError(t, os.WriteFile(dagFile, []byte(dagContent), 0644))
|
||||
|
||||
dag, err := spec.Load(th.Context, dagFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Enqueue 3 items
|
||||
runIDs := make([]string, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
dagRunID := uuid.New().String()
|
||||
runIDs[i] = dagRunID
|
||||
|
||||
att, err := th.DAGRunStore.CreateAttempt(th.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
logFile := filepath.Join(th.Config.Paths.LogDir, dag.Name, dagRunID+".log")
|
||||
require.NoError(t, os.MkdirAll(filepath.Dir(logFile), 0755))
|
||||
|
||||
dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{},
|
||||
transform.WithLogFilePath(logFile),
|
||||
transform.WithAttemptID(att.ID()),
|
||||
transform.WithHierarchyRefs(
|
||||
execution.NewDAGRunRef(dag.Name, dagRunID),
|
||||
execution.DAGRunRef{},
|
||||
),
|
||||
)
|
||||
|
||||
require.NoError(t, att.Open(th.Context))
|
||||
require.NoError(t, att.Write(th.Context, dagStatus))
|
||||
require.NoError(t, att.Close(th.Context))
|
||||
|
||||
err = th.QueueStore.Enqueue(th.Context, "global-queue", execution.QueuePriorityLow, execution.NewDAGRunRef(dag.Name, dagRunID))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify queue has 3 items
|
||||
queuedItems, err := th.QueueStore.List(th.Context, "global-queue")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, queuedItems, 3)
|
||||
|
||||
// Start scheduler
|
||||
schedulerDone := make(chan error, 1)
|
||||
daguHome := filepath.Dir(th.Config.Paths.DAGsDir)
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(th.Context, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
thCopy := th
|
||||
thCopy.Context = ctx
|
||||
|
||||
schedulerDone <- thCopy.RunCommandWithError(t, cmd.Scheduler(), test.CmdTest{
|
||||
Args: []string{
|
||||
"scheduler",
|
||||
"--dagu-home", daguHome,
|
||||
},
|
||||
ExpectedOut: []string{"Scheduler started"},
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait until all DAGs complete
|
||||
require.Eventually(t, func() bool {
|
||||
remaining, err := th.QueueStore.List(th.Context, "global-queue")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
t.Logf("Queue: %d/3 items remaining", len(remaining))
|
||||
return len(remaining) == 0
|
||||
}, 15*time.Second, 200*time.Millisecond, "Queue items should be processed")
|
||||
|
||||
th.Cancel()
|
||||
|
||||
select {
|
||||
case <-schedulerDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
|
||||
// Collect start times from all DAG runs
|
||||
var startTimes []time.Time
|
||||
for _, runID := range runIDs {
|
||||
attempt, err := th.DAGRunStore.FindAttempt(th.Context, execution.NewDAGRunRef(dag.Name, runID))
|
||||
require.NoError(t, err)
|
||||
|
||||
status, err := attempt.ReadStatus(th.Context)
|
||||
require.NoError(t, err)
|
||||
|
||||
startedAt, err := stringutil.ParseTime(status.StartedAt)
|
||||
require.NoError(t, err, "Failed to parse start time for run %s", runID)
|
||||
require.False(t, startedAt.IsZero(), "Start time is zero for run %s", runID)
|
||||
startTimes = append(startTimes, startedAt)
|
||||
}
|
||||
|
||||
// All 3 DAGs should have started
|
||||
require.Len(t, startTimes, 3, "All 3 DAGs should have started")
|
||||
|
||||
// Find the max difference between start times
|
||||
// If they ran concurrently, all should start within ~500ms
|
||||
// If they ran sequentially (maxConcurrency=1), they'd be ~1s apart
|
||||
var maxDiff time.Duration
|
||||
for i := 0; i < len(startTimes); i++ {
|
||||
for j := i + 1; j < len(startTimes); j++ {
|
||||
diff := startTimes[i].Sub(startTimes[j]).Abs()
|
||||
if diff > maxDiff {
|
||||
maxDiff = diff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("Start times: %v", startTimes)
|
||||
t.Logf("Max difference between start times: %v", maxDiff)
|
||||
|
||||
// All DAGs should start within 2 seconds of each other (concurrent execution)
|
||||
// If maxConcurrency was incorrectly set to 1, they would start ~3+ seconds apart
|
||||
// (due to 1 second sleep in each DAG + processing overhead)
|
||||
require.Less(t, maxDiff, 2*time.Second,
|
||||
"All 3 DAGs should start concurrently (within 2s), but max diff was %v", maxDiff)
|
||||
}
|
||||
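The WithConfigMutator above builds the queue settings in Go; in a deployed instance, the same thing would be expressed in the base configuration file. A sketch, assuming the key names follow dagu's documented queues config (name and maxActiveRuns per queue):

queues:
  enabled: true
  config:
    - name: global-queue
      maxActiveRuns: 3

A DAG then opts into the queue with queue: global-queue, exactly as dagContent does above; the test asserts that the queue-level limit of 3, not the DAG's own maxActiveRuns of 1, governs how many runs execute at once.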

// TestDAGQueueMaxActiveRunsFirstBatch verifies that when a DAG-based (non-global)
// queue is first encountered, all items up to maxActiveRuns are processed in the
// first batch, not just 1.
//
// This test covers the bug where dynamically created queues were initialized with
// maxConcurrency=1, causing only 1 DAG to start initially even when maxActiveRuns > 1.
// The fix reads the DAG's maxActiveRuns before selecting items for the first batch.
func TestDAGQueueMaxActiveRunsFirstBatch(t *testing.T) {
    t.Parallel()

    th := test.SetupCommand(t, test.WithConfigMutator(func(cfg *config.Config) {
        cfg.Queues.Enabled = true
        // No global queues configured - we want to test DAG-based queues
    }))

    // Create a DAG with maxActiveRuns = 3 (no queue: field, so it uses DAG-based queue)
    // Each DAG sleeps for 2 seconds to ensure we can detect concurrent vs sequential execution
    dagContent := `name: dag-queue-test
maxActiveRuns: 3
steps:
  - name: sleep-step
    command: sleep 2
`
    require.NoError(t, os.MkdirAll(th.Config.Paths.DAGsDir, 0755))
    dagFile := filepath.Join(th.Config.Paths.DAGsDir, "dag-queue-test.yaml")
    require.NoError(t, os.WriteFile(dagFile, []byte(dagContent), 0644))

    dag, err := spec.Load(th.Context, dagFile)
    require.NoError(t, err)

    // Verify queue name is the DAG name (DAG-based queue)
    queueName := dag.ProcGroup()
    require.Equal(t, "dag-queue-test", queueName, "DAG should use its name as queue")

    // Enqueue 3 items
    runIDs := make([]string, 3)
    for i := 0; i < 3; i++ {
        dagRunID := uuid.New().String()
        runIDs[i] = dagRunID

        att, err := th.DAGRunStore.CreateAttempt(th.Context, dag, time.Now(), dagRunID, execution.NewDAGRunAttemptOptions{})
        require.NoError(t, err)

        logFile := filepath.Join(th.Config.Paths.LogDir, dag.Name, dagRunID+".log")
        require.NoError(t, os.MkdirAll(filepath.Dir(logFile), 0755))

        dagStatus := transform.NewStatusBuilder(dag).Create(dagRunID, core.Queued, 0, time.Time{},
            transform.WithLogFilePath(logFile),
            transform.WithAttemptID(att.ID()),
            transform.WithHierarchyRefs(
                execution.NewDAGRunRef(dag.Name, dagRunID),
                execution.DAGRunRef{},
            ),
        )

        require.NoError(t, att.Open(th.Context))
        require.NoError(t, att.Write(th.Context, dagStatus))
        require.NoError(t, att.Close(th.Context))

        err = th.QueueStore.Enqueue(th.Context, queueName, execution.QueuePriorityLow, execution.NewDAGRunRef(dag.Name, dagRunID))
        require.NoError(t, err)
    }

    // Verify queue has 3 items
    queuedItems, err := th.QueueStore.List(th.Context, queueName)
    require.NoError(t, err)
    require.Len(t, queuedItems, 3)
    t.Logf("Enqueued 3 items to DAG-based queue %q", queueName)

    // Start scheduler
    schedulerDone := make(chan error, 1)
    daguHome := filepath.Dir(th.Config.Paths.DAGsDir)
    go func() {
        ctx, cancel := context.WithTimeout(th.Context, 30*time.Second)
        defer cancel()

        thCopy := th
        thCopy.Context = ctx

        schedulerDone <- thCopy.RunCommandWithError(t, cmd.Scheduler(), test.CmdTest{
            Args: []string{
                "scheduler",
                "--dagu-home", daguHome,
            },
            ExpectedOut: []string{"Scheduler started"},
        })
    }()

    // Wait until all DAGs complete
    startTime := time.Now()
    require.Eventually(t, func() bool {
        remaining, err := th.QueueStore.List(th.Context, queueName)
        if err != nil {
            return false
        }
        t.Logf("Queue: %d/3 items remaining", len(remaining))
        return len(remaining) == 0
    }, 20*time.Second, 200*time.Millisecond, "Queue items should be processed")

    totalDuration := time.Since(startTime)
    t.Logf("All items processed in %v", totalDuration)

    th.Cancel()

    select {
    case <-schedulerDone:
    case <-time.After(5 * time.Second):
    }

    // Collect start times from all DAG runs
    var startTimes []time.Time
    for _, runID := range runIDs {
        attempt, err := th.DAGRunStore.FindAttempt(th.Context, execution.NewDAGRunRef(dag.Name, runID))
        require.NoError(t, err)

        status, err := attempt.ReadStatus(th.Context)
        require.NoError(t, err)

        startedAt, err := stringutil.ParseTime(status.StartedAt)
        require.NoError(t, err, "Failed to parse start time for run %s", runID)
        require.False(t, startedAt.IsZero(), "Start time is zero for run %s", runID)
        startTimes = append(startTimes, startedAt)
    }

    // All 3 DAGs should have started
    require.Len(t, startTimes, 3, "All 3 DAGs should have started")

    // Find the max difference between start times
    var maxDiff time.Duration
    for i := 0; i < len(startTimes); i++ {
        for j := i + 1; j < len(startTimes); j++ {
            diff := startTimes[i].Sub(startTimes[j]).Abs()
            if diff > maxDiff {
                maxDiff = diff
            }
        }
    }

    t.Logf("Start times: %v", startTimes)
    t.Logf("Max difference between start times: %v", maxDiff)

    // KEY ASSERTION: All 3 DAGs should start in the FIRST batch (concurrently)
    // If the bug exists (queue initialized with maxConcurrency=1), they would start
    // sequentially: first at 0s, second at ~2s, third at ~4s (total ~6s+)
    // With the fix, all 3 start within the first batch, so max diff <= 2s
    require.LessOrEqual(t, maxDiff, 2*time.Second,
        "All 3 DAGs should start in first batch (within 2s), but max diff was %v. "+
            "This suggests maxActiveRuns was not applied to the first batch.", maxDiff)

    // Also verify total time is reasonable
    // - Concurrent execution: ~2s sleep + scheduler overhead (~4-6s total)
    // - Sequential execution (bug): ~6s sleep + overhead (~8-10s total)
    require.Less(t, totalDuration, 8*time.Second,
        "Total processing time should be under 8s for concurrent execution, but was %v", totalDuration)
}
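The fix described in the doc comment above is not part of this hunk, but its shape can be sketched. A hypothetical illustration, with every name invented for this example (the real scheduler internals may differ):

    // firstBatchConcurrency returns how many items to dispatch the first time a
    // queue is seen. Global queues use their configured maxActiveRuns; a DAG-based
    // queue (whose name is the DAG name) falls back to the DAG's own maxActiveRuns,
    // and finally to 1.
    func firstBatchConcurrency(queueName string, globalQueues map[string]int, dagMaxActiveRuns func(name string) (int, bool)) int {
        if n, ok := globalQueues[queueName]; ok && n > 0 {
            return n
        }
        if n, ok := dagMaxActiveRuns(queueName); ok && n > 0 {
            return n
        }
        return 1
    }

The bug the test guards against corresponds to skipping both lookups and always starting a newly seen queue at concurrency 1.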

// TestCronScheduleRunsTwice verifies that a DAG with a */1 * * * * schedule
// runs twice in two minutes.
func TestCronScheduleRunsTwice(t *testing.T) {
Some files were not shown because too many files have changed in this diff.
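The body of TestCronScheduleRunsTwice is among the truncated content. For reference, a DAG on such a schedule would be declared roughly as follows (a sketch using dagu's schedule field for cron expressions; names are illustrative):

name: every-minute
schedule: "*/1 * * * *"
steps:
  - name: tick
    command: echo tick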