451 lines
8.8 KiB
Go
451 lines
8.8 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"errors"
|
||
"flag"
|
||
"fmt"
|
||
"github.com/dustin/go-humanize"
|
||
"google.golang.org/grpc"
|
||
"io"
|
||
"io/ioutil"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
transportpb "yunlian-file-v2/proto"
|
||
)
|
||
|
||
var group sync.WaitGroup
|
||
var pi chan string = make(chan string, 2)
|
||
var exit chan bool = make(chan bool, 1)
|
||
var exitSub chan bool = make(chan bool, 1)
|
||
var downloadContext *fileContext = new(fileContext)
|
||
|
||
var thread int64
|
||
var root string
|
||
var attempt int
|
||
var fn string
|
||
var ip string
|
||
var port int
|
||
|
||
// fileInfo describes the remote file being downloaded.
type fileInfo struct {
	// filePath is the path given with -fn. It is sent to the server with every
	// block request and is also the local path the merged file is written to.
	// fileName is its last "/"-separated component.
	filePath, fileName string

	// length is the file size in bytes reported by the server.
	length int64

	// savePath is fileName joined onto the temporary directory root.
	savePath string
}
|
||
// block is one contiguous byte range of the remote file. chunkFile links the
// blocks through previous, from the last block back to the first.
type block struct {
	// previous is the block before this one; nil for the first block.
	previous *block
	// id names the block's temporary file (it is passed to filePath).
	id string
	// start and end delimit the half-open byte range [start, end).
	start, end int64
	// count is the number of retries startDownloadTask has used so far.
	count int
}
|
||
type fileContext struct {
|
||
file *fileInfo
|
||
fileNames map[string]*block
|
||
tempList []string
|
||
}
|
||
|
||
func init() {
|
||
flag.Int64Var(&thread, "thread", 1, "下载线程数")
|
||
flag.IntVar(&attempt, "attempt", 10, "下载重试次数")
|
||
flag.StringVar(&root, "root", "./tempdir", "临时文件目录")
|
||
flag.StringVar(&fn, "fn", "", "要下载的文件名")
|
||
flag.StringVar(&ip, "ip", "127.0.0.1", "服务端IP")
|
||
flag.IntVar(&port, "port", 10086, "服务端端口号")
|
||
|
||
flag.Usage = func() {
|
||
fmt.Println("yunlian-file-v2 version: 2.0.1")
|
||
flag.PrintDefaults()
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
flag.Parse()
|
||
|
||
if len(fn) < 1 {
|
||
flag.Usage()
|
||
return
|
||
}
|
||
|
||
serverIp := fmt.Sprintf("%s:%d", ip, port)
|
||
|
||
clientConn, err := grpc.Dial(serverIp, grpc.WithInsecure())
|
||
if err != nil {
|
||
log.Fatalf("监听服务端异常 : %v\n", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
log.Printf("连接到 %s ...", serverIp)
|
||
|
||
fsClient := transportpb.NewFileServiceClient(clientConn)
|
||
|
||
Download(fn, fsClient)
|
||
}
|
||
|
||
func Download(filePath string, client transportpb.FileServiceClient) (err error) {
|
||
fileLength := distribFile(filePath, client)
|
||
|
||
if fileLength < 1 {
|
||
return errors.New("获取文件信息失败")
|
||
}
|
||
|
||
go cleanCache()
|
||
|
||
log.Printf("正在下载: %s (%s)", filePath, humanize.IBytes(uint64(fileLength)))
|
||
|
||
previous := time.Now()
|
||
|
||
for key, meta := range downloadContext.fileNames {
|
||
if checkBlockStat(key, meta) {
|
||
continue
|
||
}
|
||
|
||
log.Printf("正在下载:%s,区块:%d-%d", key, meta.start, meta.end)
|
||
group.Add(1)
|
||
go startDownloadTask(downloadContext.file.filePath, key, meta, client)
|
||
}
|
||
processBar(downloadContext.file.length, previous)
|
||
group.Wait()
|
||
|
||
log.Println("区块下载完成,正在合并文件...")
|
||
|
||
err = createFileOnly(downloadContext.file.filePath)
|
||
if err != nil {
|
||
log.Println(err.Error())
|
||
panic(err)
|
||
}
|
||
|
||
for i := len(downloadContext.tempList) - 1; i >= 0; i-- {
|
||
err = appendToFile(downloadContext.file.filePath, readFile(downloadContext.tempList[i]))
|
||
if err != nil {
|
||
log.Println(err.Error(), "下载失败,请重试")
|
||
return
|
||
}
|
||
if i == 0 {
|
||
exit <- true
|
||
}
|
||
}
|
||
|
||
flag := <-exit
|
||
if flag {
|
||
log.Println("合并完成,正在清除临时文件...")
|
||
for _, file := range downloadContext.tempList {
|
||
deleteFile(file)
|
||
}
|
||
log.Println("下载完成")
|
||
return
|
||
}
|
||
log.Println("下载失败,请重试")
|
||
return
|
||
}
|
||
|
||
func startDownloadTask(filePath string, tempFilePath string, b *block, client transportpb.FileServiceClient) {
|
||
existSize := getFileSize(tempFilePath)
|
||
|
||
req := &transportpb.FileRequest{
|
||
FileName: filePath,
|
||
FileRangeStart: b.start + existSize,
|
||
FileRangeEnd: b.end,
|
||
}
|
||
|
||
stream, err := client.Download(context.Background(), req)
|
||
if err != nil {
|
||
log.Fatalf("下载异常 : %v\n", err)
|
||
}
|
||
|
||
err = createFileOnly(tempFilePath)
|
||
if err != nil {
|
||
log.Println(err.Error())
|
||
panic(err)
|
||
}
|
||
|
||
var recvSize int64 = 0
|
||
|
||
for {
|
||
res, err := stream.Recv()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
|
||
block := res.GetBlock()
|
||
blockSize := len(block)
|
||
recvSize += int64(blockSize)
|
||
|
||
if blockSize != 0 {
|
||
err := appendToFile(tempFilePath, block)
|
||
if err != nil {
|
||
log.Fatalf("临时文件保存异常: %s\n", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
if err != nil || recvSize != (b.end-(b.start+existSize)) {
|
||
log.Println("下载重试中")
|
||
if b.count > attempt {
|
||
pi <- b.id
|
||
err = nil
|
||
}
|
||
if b.count <= attempt {
|
||
b.count++
|
||
startDownloadTask(filePath, tempFilePath, b, client)
|
||
}
|
||
}
|
||
|
||
if err == nil {
|
||
group.Done()
|
||
}
|
||
}
|
||
|
||
func cleanCache() {
|
||
for {
|
||
select {
|
||
case str := <-pi:
|
||
p := filePath(str)
|
||
deleteFile(p)
|
||
exit <- false
|
||
exitSub <- false
|
||
case <-exitSub:
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
func processBar(length int64, t time.Time) {
|
||
for {
|
||
var sum int64 = 0
|
||
for key, _ := range downloadContext.fileNames {
|
||
sum += getFileSize(key)
|
||
}
|
||
percent := getPercent(sum, length)
|
||
result, _ := strconv.Atoi(percent)
|
||
str := percent + "%" + "[" + bar(result, 100) + "] " + " " + fmt.Sprintf("%.f", getCurrentSize(t)) + "s"
|
||
fmt.Printf("\r%s", str)
|
||
time.Sleep(100 * time.Millisecond)
|
||
if sum == length {
|
||
fmt.Println("")
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// bar renders a progress bar size cells wide: the first count cells are "="
// and the rest are spaces. count is clamped to [0, size]; a size <= 0 yields
// the empty string.
func bar(count, size int) string {
	if size <= 0 {
		return ""
	}
	if count < 0 {
		count = 0
	} else if count > size {
		count = size
	}
	// strings.Repeat builds each run in one allocation instead of the
	// quadratic per-character concatenation.
	return strings.Repeat("=", count) + strings.Repeat(" ", size-count)
}
|
||
|
||
// getPercent returns a as a percentage of b, rounded to a whole number and
// formatted without decimals (for example getPercent(1, 3) == "33").
func getPercent(a int64, b int64) string {
	ratio := float64(a) / float64(b)
	return fmt.Sprintf("%.0f", ratio*100)
}
|
||
|
||
// getCurrentSize returns the number of seconds elapsed since t. Despite its
// name, it measures time, not size.
func getCurrentSize(t time.Time) float64 {
	elapsed := time.Since(t)
	return elapsed.Seconds()
}
|
||
|
||
func distribFile(fPath string, client transportpb.FileServiceClient) int64 {
|
||
fileNameNotExt, fileName := parseName(fPath)
|
||
|
||
infoReq := &transportpb.FileInfoRequest{
|
||
FileName: fileName,
|
||
}
|
||
fileInfoResp, err := client.GetFileInfo(context.Background(), infoReq)
|
||
if err != nil {
|
||
log.Fatalf("获取文件信息异常: %v\n", err)
|
||
}
|
||
|
||
length := fileInfoResp.FileSize
|
||
|
||
if !checkFileStat(root) {
|
||
err := os.MkdirAll(root, 0777)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
}
|
||
|
||
downloadContext.file = &fileInfo{filePath: fPath, fileName: fileName, length: length, savePath: filePath(fileName)}
|
||
blocks := chunkFile(length, thread, fileNameNotExt)
|
||
distribBlock(blocks)
|
||
return length
|
||
|
||
}
|
||
|
||
// parseName takes the part of fPath after its last "/" as fullName, and that
// part without its final "."-extension as tmpName. A name without "." is
// returned unchanged as tmpName, and ".bashrc" yields an empty tmpName. Only
// "/" is treated as a separator.
func parseName(fPath string) (tmpName, fullName string) {
	// LastIndex returns -1 when there is no "/", so +1 starts at 0.
	fullName = fPath[strings.LastIndex(fPath, "/")+1:]

	tmpName = fullName
	if dot := strings.LastIndex(fullName, "."); dot >= 0 {
		tmpName = fullName[:dot]
	}
	return tmpName, fullName
}
|
||
|
||
func distribBlock(b *block) {
|
||
if b == nil {
|
||
return
|
||
}
|
||
m := make(map[string]*block)
|
||
listId := []string{}
|
||
p := filePath(b.id)
|
||
m[p] = b
|
||
listId = append(listId, p)
|
||
for b.previous != nil {
|
||
b = b.previous
|
||
p = filePath(b.id)
|
||
m[p] = b
|
||
listId = append(listId, p)
|
||
}
|
||
downloadContext.fileNames = m
|
||
downloadContext.tempList = listId
|
||
}
|
||
|
||
func chunkFile(length int64, thread int64, name string) (b *block) {
|
||
blockSize := length / thread
|
||
surplus := length % thread
|
||
b = nil
|
||
var start int64
|
||
var i int64
|
||
if surplus == 0 {
|
||
for i = 1; i <= thread; i++ {
|
||
seg := new(block)
|
||
r := name + MD5(strconv.FormatInt(i, 10))
|
||
seg.id = r
|
||
seg.previous = b
|
||
seg.start = start
|
||
seg.end = blockSize * i
|
||
start = blockSize * i
|
||
b = seg
|
||
}
|
||
} else {
|
||
for i = 1; i <= thread; i++ {
|
||
seg := new(block)
|
||
r := name + MD5(strconv.FormatInt(i, 10))
|
||
seg.id = r
|
||
seg.previous = b
|
||
seg.start = start
|
||
if i == thread {
|
||
seg.end = blockSize*i + surplus
|
||
} else {
|
||
seg.end = blockSize * i
|
||
}
|
||
start = blockSize * i
|
||
b = seg
|
||
}
|
||
}
|
||
return b
|
||
}
|
||
|
||
// MD5 returns the MD5 digest of text as a 32-character lowercase hex string.
func MD5(text string) string {
	sum := md5.Sum([]byte(text))
	return hex.EncodeToString(sum[:])
}
|
||
|
||
func createFileOnly(file string) error {
|
||
if checkFileStat(file) {
|
||
deleteFile(file)
|
||
}
|
||
f, err := os.Create(file)
|
||
if err != nil {
|
||
log.Println(file, "文件创建失败")
|
||
}
|
||
defer f.Close()
|
||
return err
|
||
}
|
||
|
||
// deleteFile removes file. A file that does not exist is not an error. Any
// other failure is logged and returned.
//
// Removing directly (instead of checking first, then removing) avoids a race
// in which the file disappears between the check and the removal, e.g.
// because cleanCache deleted it concurrently.
func deleteFile(file string) error {
	err := os.Remove(file)
	if err == nil || os.IsNotExist(err) {
		return nil
	}
	log.Println(file, "文件删除失败")
	return err
}
|
||
|
||
// checkFileStat reports whether file exists. Only a "does not exist" error
// from os.Stat counts as absent; any other Stat error still reports true.
func checkFileStat(file string) bool {
	_, err := os.Stat(file)
	return !os.IsNotExist(err)
}
|
||
|
||
func checkBlockStat(filePath string, b *block) bool {
|
||
m := checkFileStat(filePath)
|
||
if m {
|
||
if int64(len(readFile(filePath))) == (b.end - b.start) {
|
||
return true
|
||
} else {
|
||
// deleteFile(filePath)
|
||
return false
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// appendToFile appends content to the end of the existing file fileName. The
// file is not created if it is missing; that case is an error. Open failures
// are logged; open, write and close errors are all returned.
func appendToFile(fileName string, content []byte) (err error) {
	// O_APPEND makes every write land at the current end of file, replacing
	// the manual Seek (whose error used to be ignored) + WriteAt.
	f, err := os.OpenFile(fileName, os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		log.Println("file append failed. err: " + err.Error())
		return err
	}
	defer func() {
		// A failed Close can mean the data never reached the file.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = f.Write(content)
	return err
}
|
||
|
||
// readFile returns the entire contents of the file at path. It panics if the
// file cannot be opened or cannot be read completely, so callers never get a
// silently truncated result.
func readFile(path string) []byte {
	fi, err := os.Open(path)
	if err != nil {
		panic(err)
	}
	defer fi.Close()

	data, err := ioutil.ReadAll(fi)
	if err != nil {
		panic(err)
	}
	return data
}
|
||
|
||
func filePath(id string) string {
|
||
var file string
|
||
file = filepath.Join(root, id)
|
||
return file
|
||
}
|
||
|
||
// getFileSize returns the size of file in bytes, or 0 if it cannot be
// stat'ed (most commonly because it does not exist yet).
//
// A single os.Stat is used. Checking existence first and then calling Stat
// again dereferenced a nil FileInfo whenever the second Stat failed: on
// errors other than "not exist" (e.g. ENOTDIR, EACCES), or when the file was
// removed between the two calls.
func getFileSize(file string) int64 {
	fi, err := os.Stat(file)
	if err != nil {
		return 0
	}
	return fi.Size()
}
|