// run -goexperiment rangefunc

// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Test the 'for range' construct ranging over functions.

package main

var gj int

func yield4x(yield func() bool) {
	_ = yield() && yield() && yield() && yield()
}

func yield4(yield func(int) bool) {
	_ = yield(1) && yield(2) && yield(3) && yield(4)
}

func yield3(yield func(int) bool) {
	_ = yield(1) && yield(2) && yield(3)
}

func yield2(yield func(int) bool) {
	_ = yield(1) && yield(2)
}

func testfunc0() {
	j := 0
	for range yield4x {
		j++
	}
	if j != 4 {
		println("wrong count ranging over yield4x:", j)
		panic("testfunc0")
	}

	j = 0
	for _ = range yield4 {
		j++
	}
	if j != 4 {
		println("wrong count ranging over yield4:", j)
		panic("testfunc0")
	}
}

func testfunc1() {
	bad := false
	j := 1
	for i := range yield4 {
		if i != j {
			println("range var", i, "want", j)
			bad = true
		}
		j++
	}
	if j != 5 {
		println("wrong count ranging over f:", j)
		bad = true
	}
	if bad {
		panic("testfunc1")
	}
}

func testfunc2() {
	bad := false
	j := 1
	var i int
	for i = range yield4 {
		if i != j {
			println("range var", i, "want", j)
			bad = true
		}
		j++
	}
	if j != 5 {
		println("wrong count ranging over f:", j)
		bad = true
	}
	if i != 4 {
		println("wrong final i ranging over f:", i)
		bad = true
	}
	if bad {
		panic("testfunc2")
	}
}

func testfunc3() {
	bad := false
	j := 1
	var i int
	for i = range yield4 {
		if i != j {
			println("range var", i, "want", j)
			bad = true
		}
		j++
		if i == 2 {
			break
		}
		continue
	}
	if j != 3 {
		println("wrong count ranging over f:", j)
		bad = true
	}
	if i != 2 {
		println("wrong final i ranging over f:", i)
		bad = true
	}
	if bad {
		panic("testfunc3")
	}
}

func testfunc4() {
	bad := false
	j := 1
	var i int
	func() {
		for i = range yield4 {
			if i != j {
				println("range var", i, "want", j)
				bad = true
			}
			j++
			if i == 2 {
				return
			}
		}
	}()
	if j != 3 {
		println("wrong count ranging over f:", j)
		bad = true
	}
	if i != 2 {
		println("wrong final i ranging over f:", i)
		bad = true
	}
	if bad {
		panic("testfunc3")
	}
}

func func5() (int, int) {
	for i := range yield4 {
		return 10, i
	}
	panic("still here")
}

func testfunc5() {
	x, y := func5()
	if x != 10 || y != 1 {
		println("wrong results", x, y, "want", 10, 1)
		panic("testfunc5")
	}
}

func func6() (z, w int) {
	for i := range yield4 {
		z = 10
		w = i
		return
	}
	panic("still here")
}

func testfunc6() {
	x, y := func6()
	if x != 10 || y != 1 {
		println("wrong results", x, y, "want", 10, 1)
		panic("testfunc6")
	}
}

var saved []int

func save(x int) {
	saved = append(saved, x)
}

func printslice(s []int) {
	print("[")
	for i, x := range s {
		if i > 0 {
			print(", ")
		}
		print(x)
	}
	print("]")
}

func eqslice(s, t []int) bool {
	if len(s) != len(t) {
		return false
	}
	for i, x := range s {
		if x != t[i] {
			return false
		}
	}
	return true
}

func func7() {
	defer save(-1)
	for i := range yield4 {
		defer save(i)
	}
	defer save(5)
}

func checkslice(name string, saved, want []int) {
	if !eqslice(saved, want) {
		print("wrong results ")
		printslice(saved)
		print(" want ")
		printslice(want)
		print("\n")
		panic(name)
	}
}

func testfunc7() {
	saved = nil
	func7()
	want := []int{5, 4, 3, 2, 1, -1}
	checkslice("testfunc7", saved, want)
}

func func8() {
	defer save(-1)
	for i := range yield2 {
		for j := range yield3 {
			defer save(i*10 + j)
		}
		defer save(i)
	}
	defer save(-2)
	for i := range yield4 {
		defer save(i)
	}
	defer save(-3)
}

func testfunc8() {
	saved = nil
	func8()
	want := []int{-3, 4, 3, 2, 1, -2, 2, 23, 22, 21, 1, 13, 12, 11, -1}
	checkslice("testfunc8", saved, want)
}

func func9() {
	n := 0
	for _ = range yield2 {
		for _ = range yield3 {
			n++
			defer save(n)
		}
	}
}

func testfunc9() {
	saved = nil
	func9()
	want := []int{6, 5, 4, 3, 2, 1}
	checkslice("testfunc9", saved, want)
}

// test that range evaluates the index and value expressions
// exactly once per iteration.

var ncalls = 0

func getvar(p *int) *int {
	ncalls++
	return p
}

func iter2(list ...int) func(func(int, int) bool) {
	return func(yield func(int, int) bool) {
		for i, x := range list {
			if !yield(i, x) {
				return
			}
		}
	}
}

func testcalls() {
	var i, v int
	ncalls = 0
	si := 0
	sv := 0
	for *getvar(&i), *getvar(&v) = range iter2(1, 2) {
		si += i
		sv += v
	}
	if ncalls != 4 {
		println("wrong number of calls:", ncalls, "!= 4")
		panic("fail")
	}
	if si != 1 || sv != 3 {
		println("wrong sum in testcalls", si, sv)
		panic("fail")
	}
}

type iter3YieldFunc func(int, int) bool

func iter3(list ...int) func(iter3YieldFunc) {
	return func(yield iter3YieldFunc) {
		for k, v := range list {
			if !yield(k, v) {
				return
			}
		}
	}
}

func testcalls1() {
	ncalls := 0
	for k, v := range iter3(1, 2, 3) {
		_, _ = k, v
		ncalls++
	}
	if ncalls != 3 {
		println("wrong number of calls:", ncalls, "!= 3")
		panic("fail")
	}
}

func main() {
	testfunc0()
	testfunc1()
	testfunc2()
	testfunc3()
	testfunc4()
	testfunc5()
	testfunc6()
	testfunc7()
	testfunc8()
	testfunc9()
	testcalls()
	testcalls1()
}