Home | History | Annotate | Download | only in registry
      1 // Copyright 2015 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build windows
      6 
      7 package registry
      8 
      9 import (
     10 	"errors"
     11 	"io"
     12 	"syscall"
     13 	"unicode/utf16"
     14 	"unsafe"
     15 )
     16 
     17 const (
     18 	// Registry value types.
     19 	NONE                       = 0
     20 	SZ                         = 1
     21 	EXPAND_SZ                  = 2
     22 	BINARY                     = 3
     23 	DWORD                      = 4
     24 	DWORD_BIG_ENDIAN           = 5
     25 	LINK                       = 6
     26 	MULTI_SZ                   = 7
     27 	RESOURCE_LIST              = 8
     28 	FULL_RESOURCE_DESCRIPTOR   = 9
     29 	RESOURCE_REQUIREMENTS_LIST = 10
     30 	QWORD                      = 11
     31 )
     32 
     33 var (
     34 	// ErrShortBuffer is returned when the buffer was too short for the operation.
     35 	ErrShortBuffer = syscall.ERROR_MORE_DATA
     36 
     37 	// ErrNotExist is returned when a registry key or value does not exist.
     38 	ErrNotExist = syscall.ERROR_FILE_NOT_FOUND
     39 
     40 	// ErrUnexpectedType is returned by Get*Value when the value's type was unexpected.
     41 	ErrUnexpectedType = errors.New("unexpected key value type")
     42 )
     43 
     44 // GetValue retrieves the type and data for the specified value associated
     45 // with an open key k. It fills up buffer buf and returns the retrieved
     46 // byte count n. If buf is too small to fit the stored value it returns
     47 // ErrShortBuffer error along with the required buffer size n.
     48 // If no buffer is provided, it returns true and actual buffer size n.
     49 // If no buffer is provided, GetValue returns the value's type only.
     50 // If the value does not exist, the error returned is ErrNotExist.
     51 //
     52 // GetValue is a low level function. If value's type is known, use the appropriate
     53 // Get*Value function instead.
     54 func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) {
     55 	pname, err := syscall.UTF16PtrFromString(name)
     56 	if err != nil {
     57 		return 0, 0, err
     58 	}
     59 	var pbuf *byte
     60 	if len(buf) > 0 {
     61 		pbuf = (*byte)(unsafe.Pointer(&buf[0]))
     62 	}
     63 	l := uint32(len(buf))
     64 	err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l)
     65 	if err != nil {
     66 		return int(l), valtype, err
     67 	}
     68 	return int(l), valtype, nil
     69 }
     70 
     71 func (k Key) getValue(name string, buf []byte) (date []byte, valtype uint32, err error) {
     72 	p, err := syscall.UTF16PtrFromString(name)
     73 	if err != nil {
     74 		return nil, 0, err
     75 	}
     76 	var t uint32
     77 	n := uint32(len(buf))
     78 	for {
     79 		err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
     80 		if err == nil {
     81 			return buf[:n], t, nil
     82 		}
     83 		if err != syscall.ERROR_MORE_DATA {
     84 			return nil, 0, err
     85 		}
     86 		if n <= uint32(len(buf)) {
     87 			return nil, 0, err
     88 		}
     89 		buf = make([]byte, n)
     90 	}
     91 }
     92 
     93 // GetStringValue retrieves the string value for the specified
     94 // value name associated with an open key k. It also returns the value's type.
     95 // If value does not exist, GetStringValue returns ErrNotExist.
     96 // If value is not SZ or EXPAND_SZ, it will return the correct value
     97 // type and ErrUnexpectedType.
     98 func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) {
     99 	data, typ, err2 := k.getValue(name, make([]byte, 64))
    100 	if err2 != nil {
    101 		return "", typ, err2
    102 	}
    103 	switch typ {
    104 	case SZ, EXPAND_SZ:
    105 	default:
    106 		return "", typ, ErrUnexpectedType
    107 	}
    108 	if len(data) == 0 {
    109 		return "", typ, nil
    110 	}
    111 	u := (*[1 << 10]uint16)(unsafe.Pointer(&data[0]))[:]
    112 	return syscall.UTF16ToString(u), typ, nil
    113 }
    114 
    115 // ExpandString expands environment-variable strings and replaces
    116 // them with the values defined for the current user.
    117 // Use ExpandString to expand EXPAND_SZ strings.
    118 func ExpandString(value string) (string, error) {
    119 	if value == "" {
    120 		return "", nil
    121 	}
    122 	p, err := syscall.UTF16PtrFromString(value)
    123 	if err != nil {
    124 		return "", err
    125 	}
    126 	r := make([]uint16, 100)
    127 	for {
    128 		n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r)))
    129 		if err != nil {
    130 			return "", err
    131 		}
    132 		if n <= uint32(len(r)) {
    133 			u := (*[1 << 15]uint16)(unsafe.Pointer(&r[0]))[:]
    134 			return syscall.UTF16ToString(u), nil
    135 		}
    136 		r = make([]uint16, n)
    137 	}
    138 }
    139 
    140 // GetStringsValue retrieves the []string value for the specified
    141 // value name associated with an open key k. It also returns the value's type.
    142 // If value does not exist, GetStringsValue returns ErrNotExist.
    143 // If value is not MULTI_SZ, it will return the correct value
    144 // type and ErrUnexpectedType.
    145 func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) {
    146 	data, typ, err2 := k.getValue(name, make([]byte, 64))
    147 	if err2 != nil {
    148 		return nil, typ, err2
    149 	}
    150 	if typ != MULTI_SZ {
    151 		return nil, typ, ErrUnexpectedType
    152 	}
    153 	if len(data) == 0 {
    154 		return nil, typ, nil
    155 	}
    156 	p := (*[1 << 24]uint16)(unsafe.Pointer(&data[0]))[:len(data)/2]
    157 	if len(p) == 0 {
    158 		return nil, typ, nil
    159 	}
    160 	if p[len(p)-1] == 0 {
    161 		p = p[:len(p)-1] // remove terminating null
    162 	}
    163 	val = make([]string, 0, 5)
    164 	from := 0
    165 	for i, c := range p {
    166 		if c == 0 {
    167 			val = append(val, string(utf16.Decode(p[from:i])))
    168 			from = i + 1
    169 		}
    170 	}
    171 	return val, typ, nil
    172 }
    173 
    174 // GetIntegerValue retrieves the integer value for the specified
    175 // value name associated with an open key k. It also returns the value's type.
    176 // If value does not exist, GetIntegerValue returns ErrNotExist.
    177 // If value is not DWORD or QWORD, it will return the correct value
    178 // type and ErrUnexpectedType.
    179 func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) {
    180 	data, typ, err2 := k.getValue(name, make([]byte, 8))
    181 	if err2 != nil {
    182 		return 0, typ, err2
    183 	}
    184 	switch typ {
    185 	case DWORD:
    186 		if len(data) != 4 {
    187 			return 0, typ, errors.New("DWORD value is not 4 bytes long")
    188 		}
    189 		return uint64(*(*uint32)(unsafe.Pointer(&data[0]))), DWORD, nil
    190 	case QWORD:
    191 		if len(data) != 8 {
    192 			return 0, typ, errors.New("QWORD value is not 8 bytes long")
    193 		}
    194 		return uint64(*(*uint64)(unsafe.Pointer(&data[0]))), QWORD, nil
    195 	default:
    196 		return 0, typ, ErrUnexpectedType
    197 	}
    198 }
    199 
    200 // GetBinaryValue retrieves the binary value for the specified
    201 // value name associated with an open key k. It also returns the value's type.
    202 // If value does not exist, GetBinaryValue returns ErrNotExist.
    203 // If value is not BINARY, it will return the correct value
    204 // type and ErrUnexpectedType.
    205 func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) {
    206 	data, typ, err2 := k.getValue(name, make([]byte, 64))
    207 	if err2 != nil {
    208 		return nil, typ, err2
    209 	}
    210 	if typ != BINARY {
    211 		return nil, typ, ErrUnexpectedType
    212 	}
    213 	return data, typ, nil
    214 }
    215 
    216 func (k Key) setValue(name string, valtype uint32, data []byte) error {
    217 	p, err := syscall.UTF16PtrFromString(name)
    218 	if err != nil {
    219 		return err
    220 	}
    221 	if len(data) == 0 {
    222 		return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0)
    223 	}
    224 	return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data)))
    225 }
    226 
    227 // SetDWordValue sets the data and type of a name value
    228 // under key k to value and DWORD.
    229 func (k Key) SetDWordValue(name string, value uint32) error {
    230 	return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:])
    231 }
    232 
    233 // SetQWordValue sets the data and type of a name value
    234 // under key k to value and QWORD.
    235 func (k Key) SetQWordValue(name string, value uint64) error {
    236 	return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:])
    237 }
    238 
    239 func (k Key) setStringValue(name string, valtype uint32, value string) error {
    240 	v, err := syscall.UTF16FromString(value)
    241 	if err != nil {
    242 		return err
    243 	}
    244 	buf := (*[1 << 10]byte)(unsafe.Pointer(&v[0]))[:len(v)*2]
    245 	return k.setValue(name, valtype, buf)
    246 }
    247 
    248 // SetStringValue sets the data and type of a name value
    249 // under key k to value and SZ. The value must not contain a zero byte.
    250 func (k Key) SetStringValue(name, value string) error {
    251 	return k.setStringValue(name, SZ, value)
    252 }
    253 
    254 // SetExpandStringValue sets the data and type of a name value
    255 // under key k to value and EXPAND_SZ. The value must not contain a zero byte.
    256 func (k Key) SetExpandStringValue(name, value string) error {
    257 	return k.setStringValue(name, EXPAND_SZ, value)
    258 }
    259 
    260 // SetStringsValue sets the data and type of a name value
    261 // under key k to value and MULTI_SZ. The value strings
    262 // must not contain a zero byte.
    263 func (k Key) SetStringsValue(name string, value []string) error {
    264 	ss := ""
    265 	for _, s := range value {
    266 		for i := 0; i < len(s); i++ {
    267 			if s[i] == 0 {
    268 				return errors.New("string cannot have 0 inside")
    269 			}
    270 		}
    271 		ss += s + "\x00"
    272 	}
    273 	v := utf16.Encode([]rune(ss + "\x00"))
    274 	buf := (*[1 << 10]byte)(unsafe.Pointer(&v[0]))[:len(v)*2]
    275 	return k.setValue(name, MULTI_SZ, buf)
    276 }
    277 
    278 // SetBinaryValue sets the data and type of a name value
    279 // under key k to value and BINARY.
    280 func (k Key) SetBinaryValue(name string, value []byte) error {
    281 	return k.setValue(name, BINARY, value)
    282 }
    283 
    284 // DeleteValue removes a named value from the key k.
    285 func (k Key) DeleteValue(name string) error {
    286 	return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name))
    287 }
    288 
    289 // ReadValueNames returns the value names of key k.
    290 // The parameter n controls the number of returned names,
    291 // analogous to the way os.File.Readdirnames works.
    292 func (k Key) ReadValueNames(n int) ([]string, error) {
    293 	ki, err := k.Stat()
    294 	if err != nil {
    295 		return nil, err
    296 	}
    297 	names := make([]string, 0, ki.ValueCount)
    298 	buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character
    299 loopItems:
    300 	for i := uint32(0); ; i++ {
    301 		if n > 0 {
    302 			if len(names) == n {
    303 				return names, nil
    304 			}
    305 		}
    306 		l := uint32(len(buf))
    307 		for {
    308 			err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
    309 			if err == nil {
    310 				break
    311 			}
    312 			if err == syscall.ERROR_MORE_DATA {
    313 				// Double buffer size and try again.
    314 				l = uint32(2 * len(buf))
    315 				buf = make([]uint16, l)
    316 				continue
    317 			}
    318 			if err == _ERROR_NO_MORE_ITEMS {
    319 				break loopItems
    320 			}
    321 			return names, err
    322 		}
    323 		names = append(names, syscall.UTF16ToString(buf[:l]))
    324 	}
    325 	if n > len(names) {
    326 		return names, io.EOF
    327 	}
    328 	return names, nil
    329 }
    330