yunlian-file-v2/server/server.go
2021-05-18 12:50:32 +08:00

131 lines
2.5 KiB
Go

/**
* Created by IntelliJ IDEA.
* User: Ar.M
* Date: 2020-05-10
* Time: 22:13
*/
package main
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"google.golang.org/grpc"
"io"
"log"
"net"
"os"
"path/filepath"
transportpb "yunlian-file-v2/proto"
)
// filesDir is the directory whose files the server exposes. It is resolved
// once at start-up as "files" under the process's working directory.
var filesDir string

func init() {
	wd, err := os.Getwd()
	if err != nil {
		// Ignoring this error would leave wd empty, and filepath.Join would
		// then yield the absolute path "/files" at the filesystem root. A
		// relative path is resolved by the OS against the working directory,
		// which is what was intended, so fall back to that instead.
		log.Printf("cannot determine working directory, using relative %q: %v", "files", err)
		filesDir = "files"
		return
	}
	filesDir = filepath.Join(wd, "files")
}

// server implements the FileService gRPC handlers registered in main. It
// holds no state; every handler resolves file names against filesDir.
type server struct{}
// GetFileInfo reports the name, size in bytes, and hex-encoded SHA-256 digest
// of a file stored under filesDir.
//
// req.FileName is supplied by the remote client, so the joined path is
// rejected if, after cleaning, it does not lie strictly inside filesDir
// (e.g. "../../etc/passwd" or an empty name). Errors from opening, stat-ing,
// or hashing the file are returned instead of producing a response whose
// digest covers only partially read data.
func (s *server) GetFileInfo(ctx context.Context, req *transportpb.FileInfoRequest) (*transportpb.FileInfoResponse, error) {
	fileName := req.GetFileName()
	path := filepath.Join(filesDir, fileName)

	// filepath.Join cleans its result, so any ".." that climbs out of
	// filesDir produces a path without the filesDir prefix.
	prefix := filesDir + string(filepath.Separator)
	if len(path) <= len(prefix) || path[:len(prefix)] != prefix {
		return nil, fmt.Errorf("invalid file name %q", fileName)
	}

	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// Register Close only once Open has succeeded.
	defer file.Close()

	fileInfo, err := file.Stat()
	if err != nil {
		return nil, err
	}

	hash := sha256.New()
	if _, err := io.Copy(hash, file); err != nil {
		return nil, fmt.Errorf("hashing %q: %w", fileName, err)
	}

	return &transportpb.FileInfoResponse{
		FileName:   fileName,
		FileSize:   fileInfo.Size(),
		FileSha256: hex.EncodeToString(hash.Sum(nil)),
	}, nil
}
// Download streams the byte range [FileRangeStart, FileRangeEnd) of a file
// stored under filesDir to the client in messages of at most 1024 bytes.
// Every message also carries the total size of the file.
//
// If the range extends past the end of the file, streaming stops at EOF
// without error; an empty or inverted range sends nothing. req.FileName
// comes from the remote client and is rejected if it does not resolve to a
// path strictly inside filesDir.
func (s *server) Download(req *transportpb.FileRequest, stream transportpb.FileService_DownloadServer) error {
	fileName := req.GetFileName()
	fileRangeStart := req.GetFileRangeStart()
	fileRangeEnd := req.GetFileRangeEnd()

	path := filepath.Join(filesDir, fileName)
	// filepath.Join cleans its result, so any ".." that climbs out of
	// filesDir produces a path without the filesDir prefix.
	prefix := filesDir + string(filepath.Separator)
	if len(path) <= len(prefix) || path[:len(prefix)] != prefix {
		return fmt.Errorf("invalid file name %q", fileName)
	}

	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	// Stat the opened handle so the reported size describes the file that
	// is actually being read.
	fileInfo, err := f.Stat()
	if err != nil {
		return err
	}
	fileSize := fileInfo.Size()

	if _, err := f.Seek(fileRangeStart, io.SeekStart); err != nil {
		return err
	}
	// Log message: "start reading block range <start>-<end>".
	log.Printf("开始读取区块: %d-%d\n", fileRangeStart, fileRangeEnd)

	const packageLen int64 = 1024
	remaining := fileRangeEnd - fileRangeStart
	for remaining > 0 {
		chunk := packageLen
		if remaining < chunk {
			chunk = remaining
		}
		// A fresh buffer per message: gRPC does not allow a message to be
		// modified after it has been handed to Send.
		block := make([]byte, chunk)
		bytesRead, err := f.Read(block)
		// Forward exactly the bytes read before inspecting err: a short read
		// must not pad the message with zeros, and io.Reader permits n > 0
		// together with io.EOF.
		if bytesRead > 0 {
			if sendErr := stream.Send(&transportpb.FileResponse{
				Block:    block[:bytesRead],
				FileSize: fileSize,
			}); sendErr != nil {
				return sendErr
			}
			remaining -= int64(bytesRead)
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// main binds TCP port 10086 on all interfaces, registers the FileService
// implementation, and serves gRPC requests until Serve returns. Failure to
// listen or to serve terminates the process via log.Fatalf.
func main() {
	const listenAddr = "0.0.0.0:10086"

	listener, listenErr := net.Listen("tcp", listenAddr)
	if listenErr != nil {
		log.Fatalf("Failed to listen on 10086 : %v\n", listenErr)
	}

	grpcServer := grpc.NewServer()
	transportpb.RegisterFileServiceServer(grpcServer, &server{})

	fmt.Println("Starting server on 10086")
	if serveErr := grpcServer.Serve(listener); serveErr != nil {
		log.Fatalf("failed to start server : %v\n", serveErr)
	}
}