Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .golangci.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
version: "2"
linters:
default: all
default: none
enable:
- revive
formatters:
enable:
- gci
Expand Down
10 changes: 4 additions & 6 deletions binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ import (
set "github.com/deckarep/golang-set/v2"
)

var (
// ErrDuplicateBinding is returned when Binder.Store is called with a binding whose ID has already
// been stored (which implies that the graph being executed contains multiple tasks producing
// bindings for the same Key).
ErrDuplicateBinding = errors.New("duplicate binding")
)
// ErrDuplicateBinding is returned when Binder.Store is called with a binding whose ID has already
// been stored (which implies that the graph being executed contains multiple tasks producing
// bindings for the same Key).
var ErrDuplicateBinding = errors.New("duplicate binding")

// BindStatus represents the tristate of a Binding.
type BindStatus int
Expand Down
25 changes: 20 additions & 5 deletions binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
set "github.com/deckarep/golang-set/v2"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

tg "github.com/thought-machine/taskgraph"
tgt "github.com/thought-machine/taskgraph/taskgraphtest"
)
Expand Down Expand Up @@ -45,7 +44,11 @@ func TestBindersBindingsAndKeys(t *testing.T) {
}

if err := b.Store(key2.Bind(456)); !errors.Is(err, tg.ErrDuplicateBinding) {
t.Errorf("Expected b.Store(key2.Bind(456)) to return error %v; got %v", tg.ErrDuplicateBinding, err)
t.Errorf(
"Expected b.Store(key2.Bind(456)) to return error %v; got %v",
tg.ErrDuplicateBinding,
err,
)
}

tgt.DiffPresent[string](t, b, key1, "foo")
Expand Down Expand Up @@ -144,7 +147,11 @@ func TestOverlayBinder(t *testing.T) {
tgt.ExpectPending[int](t, ob, key2)

if err := ob.Store(key1.Bind(123)); !errors.Is(err, tg.ErrDuplicateBinding) {
t.Errorf("Expected ob.Store(key1.Bind(123)) to return error %v; got %v", tg.ErrDuplicateBinding, err)
t.Errorf(
"Expected ob.Store(key1.Bind(123)) to return error %v; got %v",
tg.ErrDuplicateBinding,
err,
)
}

if err := ob.Store(key2.Bind(456)); err != nil {
Expand Down Expand Up @@ -194,10 +201,18 @@ func TestGraphTaskBinder(t *testing.T) {
})

if err := gtb.Store(key1.Bind(456)); !errors.Is(err, tg.ErrDuplicateBinding) {
t.Errorf("Expected gtb.Store(key1.Bind(456)) to return error %v; got %v", tg.ErrDuplicateBinding, err)
t.Errorf(
"Expected gtb.Store(key1.Bind(456)) to return error %v; got %v",
tg.ErrDuplicateBinding,
err,
)
}

if err := gtb.Store(key2.Bind(123)); !errors.Is(err, tg.ErrDuplicateBinding) {
t.Errorf("Expected gtb.Store(key2.Bind(123)) to return error %v; got %v", tg.ErrDuplicateBinding, err)
t.Errorf(
"Expected gtb.Store(key2.Bind(123)) to return error %v; got %v",
tg.ErrDuplicateBinding,
err,
)
}
}
4 changes: 3 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ var taskIsPalindrome = tg.Reflect[bool]{
Depends: []any{keyInput, keyReversed},
}.Locate()

var graphIsPalindrome = tg.Must(tg.New("example_graph", tg.WithTasks(taskReverseInput, taskIsPalindrome)))
var graphIsPalindrome = tg.Must(
tg.New("example_graph", tg.WithTasks(taskReverseInput, taskIsPalindrome)),
)

func Example() {
res, err := graphIsPalindrome.Run(context.Background(), keyInput.Bind("racecar"))
Expand Down
77 changes: 59 additions & 18 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ func (gn *graphNode) execute(ctx context.Context, rs *runState) (err error) {
"task %s: mismatch between task Provides declaration and returned bindings: missing bindings [%s], got extra bindings [%s]",
gn.task.Name(),
strings.Join(missing, ", "),
strings.Join(extra, ", "))
strings.Join(extra, ", "),
)
}

for _, dependent := range gn.dependents {
Expand Down Expand Up @@ -263,7 +264,8 @@ func (g *graph) Run(ctx context.Context, inputs ...Binding) (b Binder, err error
if err != nil {
result = "error"
}
executionLatency.WithLabelValues(g.name, result).Observe(float64(time.Since(startTime) / time.Millisecond))
executionLatency.WithLabelValues(g.name, result).
Observe(float64(time.Since(startTime) / time.Millisecond))
}()
base, err := g.buildInputBinder(inputs...)
if err != nil {
Expand Down Expand Up @@ -352,7 +354,10 @@ func (g *graph) AsTask(exposeKeys ...ID) (Task, error) {
}
}
if len(missing) > 0 {
return nil, wrapStackErrorf("exposed key(s) not bound after graph execution: %s", strings.Join(missing, ", "))
return nil, wrapStackErrorf(
"exposed key(s) not bound after graph execution: %s",
strings.Join(missing, ", "),
)
}

// The exposed keys are added to the external binder via the graphTaskBinder, so we don't return
Expand All @@ -371,7 +376,10 @@ func (g *graph) Graphviz(includeInputs bool) string {
for _, dep := range n.task.Depends() {
if !g.allProvided.Contains(dep) {
inputID := fmt.Sprintf("%s_input_%s", n.id, dep.id)
nodes = append(nodes, fmt.Sprintf(" %s [label=\"Input - %s\", shape=diamond];", inputID, dep))
nodes = append(
nodes,
fmt.Sprintf(" %s [label=\"Input - %s\", shape=diamond];", inputID, dep),
)
edges = append(edges, fmt.Sprintf(" %s -> %s;", inputID, n.id))
}
}
Expand All @@ -384,8 +392,14 @@ func (g *graph) Graphviz(includeInputs bool) string {
for _, dep := range n.task.Provides() {
if !g.allDependencies.Contains(dep) {
outputID := fmt.Sprintf("%s_output_%s", n.id, dep)
nodes = append(nodes, fmt.Sprintf(" %s [label=\"Output\", shape=diamond];", outputID))
edges = append(edges, fmt.Sprintf(" %s -> %s [label=\"%s\"];", n.id, outputID, dep))
nodes = append(
nodes,
fmt.Sprintf(" %s [label=\"Output\", shape=diamond];", outputID),
)
edges = append(
edges,
fmt.Sprintf(" %s -> %s [label=\"%s\"];", n.id, outputID, dep),
)
}
}
}
Expand All @@ -402,15 +416,17 @@ func (g *graph) Graphviz(includeInputs bool) string {
return buf.String()
}

type GraphOptions struct {
type graphOptions struct {
tasks []Task
tracer trace.Tracer
}

type GraphOption func(opts *GraphOptions) error
// A GraphOption is used to configure a new Graph.
type GraphOption func(opts *graphOptions) error

// WithTasks sets the tasks which form the graph.
func WithTasks(tasks ...TaskSet) GraphOption {
return func(opts *GraphOptions) error {
return func(opts *graphOptions) error {
opts.tasks = taskset(tasks).Tasks()

if len(opts.tasks) > taskLimit {
Expand All @@ -421,8 +437,9 @@ func WithTasks(tasks ...TaskSet) GraphOption {
}
}

// WithTracer sets a tracer to record graph execution.
func WithTracer(tracer trace.Tracer) GraphOption {
return func(opts *GraphOptions) error {
return func(opts *graphOptions) error {
opts.tracer = tracer

return nil
Expand All @@ -433,7 +450,7 @@ func WithTracer(tracer trace.Tracer) GraphOption {
//
// Ideally, Graphs should be created on program startup, rather than creating them dynamically.
func New(name string, opts ...GraphOption) (Graph, error) {
o := &GraphOptions{
o := &graphOptions{
tracer: noop.NewTracerProvider().Tracer("github.com/thought-machine/taskgraph"),
}

Expand All @@ -458,7 +475,10 @@ func New(name string, opts ...GraphOption) (Graph, error) {
var badTaskErrs error
for _, t := range g.tasks {
if t.Name() == "" || t.Location() == "" {
badTaskErrs = errors.Join(badTaskErrs, fmt.Errorf("tasks must have a name and location: (%s, %s)", t.Name(), t.Location()))
badTaskErrs = errors.Join(
badTaskErrs,
fmt.Errorf("tasks must have a name and location: (%s, %s)", t.Name(), t.Location()),
)
}
node := &graphNode{
id: sanitizeTaskName(t.Name()),
Expand All @@ -477,7 +497,10 @@ func New(name string, opts ...GraphOption) (Graph, error) {

g.allProvided.Append(t.Provides()...)
for _, id := range t.Provides() {
provideTasks[id.String()] = append(provideTasks[id.String()], fmt.Sprintf("%s - %s", t.Name(), t.Location()))
provideTasks[id.String()] = append(
provideTasks[id.String()],
fmt.Sprintf("%s - %s", t.Name(), t.Location()),
)
}
}
if badTaskErrs != nil {
Expand All @@ -486,20 +509,34 @@ func New(name string, opts ...GraphOption) (Graph, error) {
var duplicateTaskNames []string
for name, locations := range taskLocations {
if len(locations) > 1 {
duplicateTaskNames = append(duplicateTaskNames, fmt.Sprintf("%s (%s)", name, strings.Join(locations, ", ")))
duplicateTaskNames = append(
duplicateTaskNames,
fmt.Sprintf("%s (%s)", name, strings.Join(locations, ", ")),
)
}
}
if len(duplicateTaskNames) > 0 {
return nil, wrapStackErrorf("%w: %s", ErrDuplicateTaskNames, strings.Join(duplicateTaskNames, ", "))
return nil, wrapStackErrorf(
"%w: %s",
ErrDuplicateTaskNames,
strings.Join(duplicateTaskNames, ", "),
)
}
var duplicateProvides []string
for id, tasks := range provideTasks {
if len(tasks) > 1 {
duplicateProvides = append(duplicateProvides, fmt.Sprintf("%s (%s)", id, strings.Join(tasks, ", ")))
duplicateProvides = append(
duplicateProvides,
fmt.Sprintf("%s (%s)", id, strings.Join(tasks, ", ")),
)
}
}
if len(duplicateProvides) > 0 {
return nil, wrapStackErrorf("%w: %s", ErrDuplicateProvidedKeys, strings.Join(duplicateProvides, ", "))
return nil, wrapStackErrorf(
"%w: %s",
ErrDuplicateProvidedKeys,
strings.Join(duplicateProvides, ", "),
)
}

for _, node := range g.nodes {
Expand Down Expand Up @@ -543,7 +580,11 @@ func sanitizeTaskName(name string) string {
func checkCycle(node *graphNode, path []string) error {
for i := len(path) - 1; i >= 0; i-- {
if path[i] == node.task.Name() {
return wrapStackErrorf("%w: %s", ErrGraphCycle, strings.Join(append(path[i:], path[i]), " -> "))
return wrapStackErrorf(
"%w: %s",
ErrGraphCycle,
strings.Join(append(path[i:], path[i]), " -> "),
)
}
}
path = append(path, node.task.Name())
Expand Down
9 changes: 7 additions & 2 deletions graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"

"github.com/google/go-cmp/cmp"

tg "github.com/thought-machine/taskgraph"
tgt "github.com/thought-machine/taskgraph/taskgraphtest"
)
Expand Down Expand Up @@ -172,7 +171,13 @@ func TestGraphErrors(t *testing.T) {
for i := 0; i <= 1000; i++ {
tasks = append(tasks, tg.NewTask("task", tgt.DummyTaskFunc(), nil, nil))
}
if _, err := tg.New("test_graph", tg.WithTasks(tasks...)); !errors.Is(err, tg.ErrTooManyTasks) {
if _, err := tg.New(
"test_graph",
tg.WithTasks(tasks...),
); !errors.Is(
err,
tg.ErrTooManyTasks,
) {
t.Errorf("expected error %v; got %v", tg.ErrTooManyTasks, err)
}
})
Expand Down
8 changes: 7 additions & 1 deletion key.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ func (k *key[T]) Get(b Binder) (T, error) {
typed, ok := binding.Value().(T)
if !ok {
var want T
return empty, wrapStackErrorf("cannot get key %q: %w (got %T, want %T)", k.id, ErrWrongType, binding.Value(), want)
return empty, wrapStackErrorf(
"cannot get key %q: %w (got %T, want %T)",
k.id,
ErrWrongType,
binding.Value(),
want,
)
}
return typed, nil
default:
Expand Down
41 changes: 34 additions & 7 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ func (rk *reflectKey) Get(b Binder) (reflect.Value, error) {
if !outs[1].IsNil() {
err, ok := outs[1].Interface().(error)
if !ok {
return reflect.Value{}, wrapStackErrorf("could not convert output 1 to error; got %T", outs[1].Interface())
return reflect.Value{}, wrapStackErrorf(
"could not convert output 1 to error; got %T",
outs[1].Interface(),
)
}
return reflect.Value{}, err
}
Expand Down Expand Up @@ -163,14 +166,21 @@ func newReflectFn(fn any, resultType reflect.Type, deps ...any) (rf *reflectFn,
if !outs[1].IsNil() {
err, ok := outs[1].Interface().(error)
if !ok {
return nil, wrapStackErrorf("could not convert function output 1 to error; got %T", outs[1].Interface())
return nil, wrapStackErrorf(
"could not convert function output 1 to error; got %T",
outs[1].Interface(),
)
}
return nil, err
}
return outs[0].Interface(), nil
}
} else {
return nil, wrapStackErrorf("function should return %s or (%s, error)", resultType, resultType)
return nil, wrapStackErrorf(
"function should return %s or (%s, error)",
resultType,
resultType,
)
}

hasContext := fnType.NumIn() > 0 && fnType.In(0).Implements(contextType)
Expand All @@ -181,7 +191,11 @@ func newReflectFn(fn any, resultType reflect.Type, deps ...any) (rf *reflectFn,
offset++
}
if argCount != len(deps) {
return nil, wrapStackErrorf("function takes %d arguments (excluding any initial context), but %d deps were provided", argCount, len(deps))
return nil, wrapStackErrorf(
"function takes %d arguments (excluding any initial context), but %d deps were provided",
argCount,
len(deps),
)
}

var keys []*reflectKey
Expand All @@ -192,7 +206,12 @@ func newReflectFn(fn any, resultType reflect.Type, deps ...any) (rf *reflectFn,
return nil, wrapStackErrorf("dependency %d: %w", i, err)
}
if !rk.valueType.AssignableTo(fnType.In(i + offset)) {
return nil, wrapStackErrorf("dependency %d is Key[%v]; want Key[%v]", i, rk.valueType, fnType.In(i+offset))
return nil, wrapStackErrorf(
"dependency %d is Key[%v]; want Key[%v]",
i,
rk.valueType,
fnType.In(i+offset),
)
}
keys = append(keys, rk)
id, err := rk.ID()
Expand Down Expand Up @@ -272,7 +291,11 @@ func (r Reflect[T]) Build() (Task, error) {
}
typed, ok := res.(T)
if !ok {
return nil, wrapStackErrorf("%s: could not convert function result to T; got %T", r.errorPrefix(), res)
return nil, wrapStackErrorf(
"%s: could not convert function result to T; got %T",
r.errorPrefix(),
res,
)
}
return []Binding{r.ResultKey.Bind(typed)}, nil
},
Expand Down Expand Up @@ -341,7 +364,11 @@ func (r ReflectMulti) Build() (Task, error) {
}
typed, ok := res.([]Binding)
if !ok {
return nil, wrapStackErrorf("%s: could not convert function result to []Binding; got %T", r.errorPrefix(), res)
return nil, wrapStackErrorf(
"%s: could not convert function result to []Binding; got %T",
r.errorPrefix(),
res,
)
}
return typed, nil
},
Expand Down
Loading