Home | History | Annotate | Download | only in prog
      1 // Copyright 2015 syzkaller project authors. All rights reserved.
      2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
      3 
      4 package prog
      5 
      6 import (
      7 	"bytes"
      8 	"encoding/hex"
      9 	"fmt"
     10 	"math/rand"
     11 	"strings"
     12 	"testing"
     13 	"time"
     14 )
     15 
     16 func TestGeneration(t *testing.T) {
     17 	target, rs, iters := initTest(t)
     18 	for i := 0; i < iters; i++ {
     19 		target.Generate(rs, 20, nil)
     20 	}
     21 }
     22 
     23 func TestDefault(t *testing.T) {
     24 	target, _, _ := initTest(t)
     25 	for _, meta := range target.Syscalls {
     26 		ForeachType(meta, func(typ Type) {
     27 			arg := typ.makeDefaultArg()
     28 			if !isDefault(arg) {
     29 				t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
     30 					typ, typ, arg)
     31 			}
     32 		})
     33 	}
     34 }
     35 
     36 func TestDefaultCallArgs(t *testing.T) {
     37 	target, _, _ := initTest(t)
     38 	for _, meta := range target.SyscallMap {
     39 		// Ensure that we can restore all arguments of all calls.
     40 		prog := fmt.Sprintf("%v()", meta.Name)
     41 		p, err := target.Deserialize([]byte(prog))
     42 		if err != nil {
     43 			t.Fatalf("failed to restore default args in prog %q: %v", prog, err)
     44 		}
     45 		if len(p.Calls) != 1 || p.Calls[0].Meta.Name != meta.Name {
     46 			t.Fatalf("restored bad program from prog %q: %q", prog, p.Serialize())
     47 		}
     48 	}
     49 }
     50 
     51 func TestSerialize(t *testing.T) {
     52 	target, rs, iters := initTest(t)
     53 	for i := 0; i < iters; i++ {
     54 		p := target.Generate(rs, 10, nil)
     55 		data := p.Serialize()
     56 		p1, err := target.Deserialize(data)
     57 		if err != nil {
     58 			t.Fatalf("failed to deserialize program: %v\n%s", err, data)
     59 		}
     60 		if p1 == nil {
     61 			t.Fatalf("deserialized nil program:\n%s", data)
     62 		}
     63 		data1 := p1.Serialize()
     64 		if len(p.Calls) != len(p1.Calls) {
     65 			t.Fatalf("different number of calls")
     66 		}
     67 		if !bytes.Equal(data, data1) {
     68 			t.Fatalf("program changed after serialize/deserialize\noriginal:\n%s\n\nnew:\n%s\n", data, data1)
     69 		}
     70 	}
     71 }
     72 
     73 func TestVmaType(t *testing.T) {
     74 	target, rs, iters := initRandomTargetTest(t, "test", "64")
     75 	meta := target.SyscallMap["test$vma0"]
     76 	r := newRand(target, rs)
     77 	pageSize := target.PageSize
     78 	for i := 0; i < iters; i++ {
     79 		s := newState(target, nil)
     80 		calls := r.generateParticularCall(s, meta)
     81 		c := calls[len(calls)-1]
     82 		if c.Meta.Name != "test$vma0" {
     83 			t.Fatalf("generated wrong call %v", c.Meta.Name)
     84 		}
     85 		if len(c.Args) != 6 {
     86 			t.Fatalf("generated wrong number of args %v", len(c.Args))
     87 		}
     88 		check := func(v, l Arg, min, max uint64) {
     89 			va, ok := v.(*PointerArg)
     90 			if !ok {
     91 				t.Fatalf("vma has bad type: %v", v)
     92 			}
     93 			la, ok := l.(*ConstArg)
     94 			if !ok {
     95 				t.Fatalf("len has bad type: %v", l)
     96 			}
     97 			if va.VmaSize < min || va.VmaSize > max {
     98 				t.Fatalf("vma has bad size: %v, want [%v-%v]",
     99 					va.VmaSize, min, max)
    100 			}
    101 			if la.Val < min || la.Val > max {
    102 				t.Fatalf("len has bad value: %v, want [%v-%v]",
    103 					la.Val, min, max)
    104 			}
    105 		}
    106 		check(c.Args[0], c.Args[1], 1*pageSize, 1e5*pageSize)
    107 		check(c.Args[2], c.Args[3], 5*pageSize, 5*pageSize)
    108 		check(c.Args[4], c.Args[5], 7*pageSize, 9*pageSize)
    109 	}
    110 }
    111 
    112 // TestCrossTarget ensures that a program serialized for one arch can be
    113 // deserialized for another arch. This happens when managers exchange
    114 // programs via hub.
    115 func TestCrossTarget(t *testing.T) {
    116 	t.Parallel()
    117 	const OS = "linux"
    118 	var archs []string
    119 	for _, target := range AllTargets() {
    120 		if target.OS == OS {
    121 			archs = append(archs, target.Arch)
    122 		}
    123 	}
    124 	for _, arch := range archs {
    125 		target, err := GetTarget(OS, arch)
    126 		if err != nil {
    127 			t.Fatal(err)
    128 		}
    129 		var crossTargets []*Target
    130 		for _, crossArch := range archs {
    131 			if crossArch == arch {
    132 				continue
    133 			}
    134 			crossTarget, err := GetTarget(OS, crossArch)
    135 			if err != nil {
    136 				t.Fatal(err)
    137 			}
    138 			crossTargets = append(crossTargets, crossTarget)
    139 		}
    140 		t.Run(fmt.Sprintf("%v/%v", OS, arch), func(t *testing.T) {
    141 			t.Parallel()
    142 			testCrossTarget(t, target, crossTargets)
    143 		})
    144 	}
    145 }
    146 
    147 func testCrossTarget(t *testing.T, target *Target, crossTargets []*Target) {
    148 	seed := int64(time.Now().UnixNano())
    149 	t.Logf("seed=%v", seed)
    150 	rs := rand.NewSource(seed)
    151 	iters := 100
    152 	if testing.Short() {
    153 		iters /= 10
    154 	}
    155 	for i := 0; i < iters; i++ {
    156 		p := target.Generate(rs, 20, nil)
    157 		testCrossArchProg(t, p, crossTargets)
    158 		p, err := target.Deserialize(p.Serialize())
    159 		if err != nil {
    160 			t.Fatal(err)
    161 		}
    162 		testCrossArchProg(t, p, crossTargets)
    163 		p.Mutate(rs, 20, nil, nil)
    164 		testCrossArchProg(t, p, crossTargets)
    165 		p, _ = Minimize(p, -1, false, func(*Prog, int) bool {
    166 			return rs.Int63()%2 == 0
    167 		})
    168 		testCrossArchProg(t, p, crossTargets)
    169 	}
    170 }
    171 
    172 func testCrossArchProg(t *testing.T, p *Prog, crossTargets []*Target) {
    173 	serialized := p.Serialize()
    174 	for _, crossTarget := range crossTargets {
    175 		_, err := crossTarget.Deserialize(serialized)
    176 		if err == nil || strings.Contains(err.Error(), "unknown syscall") {
    177 			continue
    178 		}
    179 		t.Fatalf("failed to deserialize for %v/%v: %v\n%s",
    180 			crossTarget.OS, crossTarget.Arch, err, serialized)
    181 	}
    182 }
    183 
    184 func TestSpecialStructs(t *testing.T) {
    185 	testEachTargetRandom(t, func(t *testing.T, target *Target, rs rand.Source, iters int) {
    186 		for special, gen := range target.SpecialTypes {
    187 			t.Run(special, func(t *testing.T) {
    188 				var typ Type
    189 				for i := 0; i < len(target.Syscalls) && typ == nil; i++ {
    190 					ForeachType(target.Syscalls[i], func(t Type) {
    191 						if t.Dir() == DirOut {
    192 							return
    193 						}
    194 						if s, ok := t.(*StructType); ok && s.Name() == special {
    195 							typ = s
    196 						}
    197 						if s, ok := t.(*UnionType); ok && s.Name() == special {
    198 							typ = s
    199 						}
    200 					})
    201 				}
    202 				if typ == nil {
    203 					t.Fatal("can't find struct description")
    204 				}
    205 				g := &Gen{newRand(target, rs), newState(target, nil)}
    206 				for i := 0; i < iters/len(target.SpecialTypes); i++ {
    207 					arg, _ := gen(g, typ, nil)
    208 					gen(g, typ, arg)
    209 				}
    210 			})
    211 		}
    212 	})
    213 }
    214 
    215 func TestEscapingPaths(t *testing.T) {
    216 	paths := map[string]bool{
    217 		"/":                      true,
    218 		"/\x00":                  true,
    219 		"/file/..":               true,
    220 		"/file/../..":            true,
    221 		"./..":                   true,
    222 		"..":                     true,
    223 		"file/../../file":        true,
    224 		"../file":                true,
    225 		"./file/../../file/file": true,
    226 		"":          false,
    227 		".":         false,
    228 		"file":      false,
    229 		"./file":    false,
    230 		"./file/..": false,
    231 	}
    232 	target, err := GetTarget("test", "64")
    233 	if err != nil {
    234 		t.Fatal(err)
    235 	}
    236 	for path, escaping := range paths {
    237 		text := fmt.Sprintf("mutate5(&(0x7f0000000000)=\"%s\", 0x0)", hex.EncodeToString([]byte(path)))
    238 		_, err := target.Deserialize([]byte(text))
    239 		if !escaping && err != nil {
    240 			t.Errorf("path %q is detected as escaping (%v)", path, err)
    241 		}
    242 		if escaping && (err == nil || !strings.Contains(err.Error(), "sandbox escaping file")) {
    243 			t.Errorf("path %q is not detected as escaping (%v)", path, err)
    244 		}
    245 	}
    246 }
    247