package main import ( "context" "crypto/md5" "encoding/hex" "errors" "flag" "fmt" "github.com/dustin/go-humanize" "google.golang.org/grpc" "io" "io/ioutil" "log" "os" "path/filepath" "strconv" "strings" "sync" "time" transportpb "yunlian-file-v2/proto" ) var group sync.WaitGroup var pi chan string = make(chan string, 2) var exit chan bool = make(chan bool, 1) var exitSub chan bool = make(chan bool, 1) var downloadContext *fileContext = new(fileContext) var thread int64 var root string var attempt int var fn string var ip string var port int type fileInfo struct { filePath string fileName string length int64 savePath string } type block struct { previous *block id string start int64 end int64 count int } type fileContext struct { file *fileInfo fileNames map[string]*block tempList []string } func init() { flag.Int64Var(&thread, "thread", 1, "下载线程数") flag.IntVar(&attempt, "attempt", 10, "下载重试次数") flag.StringVar(&root, "root", "./tempdir", "临时文件目录") flag.StringVar(&fn, "fn", "", "要下载的文件名") flag.StringVar(&ip, "ip", "127.0.0.1", "服务端IP") flag.IntVar(&port, "port", 10086, "服务端端口号") flag.Usage = func() { fmt.Println("yunlian-file-v2 version: 2.0.1") flag.PrintDefaults() } } func main() { flag.Parse() if len(fn) < 1 { flag.Usage() return } serverIp := fmt.Sprintf("%s:%d", ip, port) clientConn, err := grpc.Dial(serverIp, grpc.WithInsecure()) if err != nil { log.Fatalf("监听服务端异常 : %v\n", err) } defer clientConn.Close() log.Printf("连接到 %s ...", serverIp) fsClient := transportpb.NewFileServiceClient(clientConn) Download(fn, fsClient) } func Download(filePath string, client transportpb.FileServiceClient) (err error) { fileLength := distribFile(filePath, client) if fileLength < 1 { return errors.New("获取文件信息失败") } go cleanCache() log.Printf("正在下载: %s (%s)", filePath, humanize.IBytes(uint64(fileLength))) previous := time.Now() for key, meta := range downloadContext.fileNames { if checkBlockStat(key, meta) { continue } log.Printf("正在下载:%s,区块:%d-%d", key, meta.start, meta.end) group.Add(1) go startDownloadTask(downloadContext.file.filePath, key, meta, client) } processBar(downloadContext.file.length, previous) group.Wait() log.Println("区块下载完成,正在合并文件...") err = createFileOnly(downloadContext.file.filePath) if err != nil { log.Println(err.Error()) panic(err) } for i := len(downloadContext.tempList) - 1; i >= 0; i-- { err = appendToFile(downloadContext.file.filePath, readFile(downloadContext.tempList[i])) if err != nil { log.Println(err.Error(), "下载失败,请重试") return } if i == 0 { exit <- true } } flag := <-exit if flag { log.Println("合并完成,正在清除临时文件...") for _, file := range downloadContext.tempList { deleteFile(file) } log.Println("下载完成") return } log.Println("下载失败,请重试") return } func startDownloadTask(filePath string, tempFilePath string, b *block, client transportpb.FileServiceClient) { existSize := getFileSize(tempFilePath) req := &transportpb.FileRequest{ FileName: filePath, FileRangeStart: b.start + existSize, FileRangeEnd: b.end, } stream, err := client.Download(context.Background(), req) if err != nil { log.Fatalf("下载异常 : %v\n", err) } err = createFileOnly(tempFilePath) if err != nil { log.Println(err.Error()) panic(err) } var recvSize int64 = 0 for { res, err := stream.Recv() if err == io.EOF { break } block := res.GetBlock() blockSize := len(block) recvSize += int64(blockSize) if blockSize != 0 { err := appendToFile(tempFilePath, block) if err != nil { log.Fatalf("临时文件保存异常: %s\n", err) } } } if err != nil || recvSize != (b.end-(b.start+existSize)) { log.Println("下载重试中") if b.count > attempt { pi <- b.id err = nil } if b.count <= attempt { b.count++ startDownloadTask(filePath, tempFilePath, b, client) } } if err == nil { group.Done() } } func cleanCache() { for { select { case str := <-pi: p := filePath(str) deleteFile(p) exit <- false exitSub <- false case <-exitSub: break } } } func processBar(length int64, t time.Time) { for { var sum int64 = 0 for key, _ := range downloadContext.fileNames { sum += getFileSize(key) } percent := getPercent(sum, length) result, _ := strconv.Atoi(percent) str := percent + "%" + "[" + bar(result, 100) + "] " + " " + fmt.Sprintf("%.f", getCurrentSize(t)) + "s" fmt.Printf("\r%s", str) time.Sleep(100 * time.Millisecond) if sum == length { fmt.Println("") break } } } func bar(count, size int) string { str := "" for i := 0; i < size; i++ { if i < count { str += "=" } else { str += " " } } return str } func getPercent(a int64, b int64) string { result := float64(a) / float64(b) * 100 return fmt.Sprintf("%.f", result) } func getCurrentSize(t time.Time) float64 { return time.Now().Sub(t).Seconds() } func distribFile(fPath string, client transportpb.FileServiceClient) int64 { fileNameNotExt, fileName := parseName(fPath) infoReq := &transportpb.FileInfoRequest{ FileName: fileName, } fileInfoResp, err := client.GetFileInfo(context.Background(), infoReq) if err != nil { log.Fatalf("获取文件信息异常: %v\n", err) } length := fileInfoResp.FileSize if !checkFileStat(root) { err := os.MkdirAll(root, 0777) if err != nil { panic(err) } } downloadContext.file = &fileInfo{filePath: fPath, fileName: fileName, length: length, savePath: filePath(fileName)} blocks := chunkFile(length, thread, fileNameNotExt) distribBlock(blocks) return length } func parseName(fPath string) (tmpName, fullName string) { u := []byte(fPath) s := strings.LastIndex(fPath, "/") if s == -1 { s = 0 fullName = string(u[s:]) } else { fullName = string(u[s+1:]) } t := []byte(fullName) d := strings.LastIndex(fullName, ".") if d == -1 { d = len(t) tmpName = string(t[:]) } else { tmpName = string(t[:d]) } return } func distribBlock(b *block) { if b == nil { return } m := make(map[string]*block) listId := []string{} p := filePath(b.id) m[p] = b listId = append(listId, p) for b.previous != nil { b = b.previous p = filePath(b.id) m[p] = b listId = append(listId, p) } downloadContext.fileNames = m downloadContext.tempList = listId } func chunkFile(length int64, thread int64, name string) (b *block) { blockSize := length / thread surplus := length % thread b = nil var start int64 var i int64 if surplus == 0 { for i = 1; i <= thread; i++ { seg := new(block) r := name + MD5(strconv.FormatInt(i, 10)) seg.id = r seg.previous = b seg.start = start seg.end = blockSize * i start = blockSize * i b = seg } } else { for i = 1; i <= thread; i++ { seg := new(block) r := name + MD5(strconv.FormatInt(i, 10)) seg.id = r seg.previous = b seg.start = start if i == thread { seg.end = blockSize*i + surplus } else { seg.end = blockSize * i } start = blockSize * i b = seg } } return b } // 生成32位MD5 func MD5(text string) string { ctx := md5.New() ctx.Write([]byte(text)) return hex.EncodeToString(ctx.Sum(nil)) } func createFileOnly(file string) error { if checkFileStat(file) { deleteFile(file) } f, err := os.Create(file) if err != nil { log.Println(file, "文件创建失败") } defer f.Close() return err } func deleteFile(file string) error { if !checkFileStat(file) { return nil } err := os.Remove(file) if err != nil { log.Println(file, "文件删除失败") } return err } func checkFileStat(file string) bool { var exist = true if _, err := os.Stat(file); os.IsNotExist(err) { exist = false } return exist } func checkBlockStat(filePath string, b *block) bool { m := checkFileStat(filePath) if m { if int64(len(readFile(filePath))) == (b.end - b.start) { return true } else { // deleteFile(filePath) return false } } return false } func appendToFile(fileName string, content []byte) error { // 以只写的模式,打开文件 f, err := os.OpenFile(fileName, os.O_WRONLY, 0644) if err != nil { log.Println("file append failed. err: " + err.Error()) } else { // 查找文件末尾的偏移量 n, _ := f.Seek(0, os.SEEK_END) // 从末尾的偏移量开始写入内容 _, err = f.WriteAt(content, n) } defer f.Close() return err } func readFile(path string) []byte { fi, err := os.Open(path) if err != nil { panic(err) } defer fi.Close() fd, _ := ioutil.ReadAll(fi) return fd } func filePath(id string) string { var file string file = filepath.Join(root, id) return file } func getFileSize(file string) int64 { if !checkFileStat(file) { return 0 } fi, _ := os.Stat(file) return fi.Size() }